From d20366982d6e1ccfc2db36b01bae82aa8db3b7cb Mon Sep 17 00:00:00 2001
From: george ohashi
Date: Fri, 12 Apr 2024 16:45:49 +0000
Subject: [PATCH 1/6] draft

---
 bin/quant.py                                  |  55 +++
 .../quantization/lifecycle/__init__.py        |  22 +
 .../quantization/lifecycle/calirbation.py     |  44 ++
 .../quantization/lifecycle/forward.py         | 155 +++++++
 .../quantization/lifecycle/frozen.py          |  37 ++
 .../quantization/lifecycle/initialize.py      |  77 ++++
 .../quantization/lifecycle/status.py          |  26 ++
 .../quantization/observers/__init__.py        |  19 +
 .../quantization/observers/base.py            |  72 ++++
 .../quantization/observers/memoryless.py      |   0
 .../quantization/observers/min_max.py         |  80 ++++
 .../quantization/utils/quantization_scheme.py | 391 ++++++++++++++++++
 12 files changed, 978 insertions(+)
 create mode 100644 bin/quant.py
 create mode 100644 src/sparsetensors/quantization/lifecycle/__init__.py
 create mode 100644 src/sparsetensors/quantization/lifecycle/calirbation.py
 create mode 100644 src/sparsetensors/quantization/lifecycle/forward.py
 create mode 100644 src/sparsetensors/quantization/lifecycle/frozen.py
 create mode 100644 src/sparsetensors/quantization/lifecycle/initialize.py
 create mode 100644 src/sparsetensors/quantization/lifecycle/status.py
 create mode 100644 src/sparsetensors/quantization/observers/__init__.py
 create mode 100644 src/sparsetensors/quantization/observers/base.py
 create mode 100644 src/sparsetensors/quantization/observers/memoryless.py
 create mode 100644 src/sparsetensors/quantization/observers/min_max.py
 create mode 100644 src/sparsetensors/quantization/utils/quantization_scheme.py

diff --git a/bin/quant.py b/bin/quant.py
new file mode 100644
index 00000000..3d33191f
--- /dev/null
+++ b/bin/quant.py
@@ -0,0 +1,55 @@
+import torch
+from torch.nn import Linear
+# from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationScheme, QuantizationArgs
+from sparsetensors.quantization.quant_args import QuantizationArgs
+from sparsetensors.quantization.quant_scheme import QuantizationScheme
+from sparseml.modifiers.quantization.lifecycle.initialize import initialize_module_for_quantization
+from sparseml.modifiers.quantization.lifecycle.calibration import set_module_for_calibration
+from sparseml.modifiers.quantization.lifecycle.frozen import freeze_module_quantization
+num_bits = 8
+
+scheme = QuantizationScheme(
+    input_acivations=QuantizationArgs(num_bits=num_bits, symmetric=False),
+    weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
+    output_activations=None,
+)
+
+layer = Linear(4, 4)
+print(layer)
+print(dict(layer.named_parameters()))
+
+
+initialize_module_for_quantization(layer, scheme)
+print(layer) # should see observer under layer now
+print(0)
+print(dict(layer.named_parameters())) # should see empty tensors for scale and zero point now
+print(1)
+
+
+set_module_for_calibration(layer)
+# do a calibration step
+layer(torch.randn(4,4))
+print(dict(layer.named_parameters())) # scale and zero point should have updated values
+print(2)
+for _ in range(10):
+    layer(torch.randn(4,4))
+print(dict(layer.named_parameters())) # scale and zero point should have updated values again since we did another pass
+
+print(3)
+breakpoint()
+
+
+freeze_module_quantization(layer)
+for _ in range(10):
+    # do more forward passes but show args are frozen
+    layer(torch.random.randn(4,4))
+print(dict(layer.named_parameters())) # scale and zero point should not be updated now
+
+
+# missing
+
+# correctness
+# quantizing an entire model
+
+
+
diff --git
a/src/sparsetensors/quantization/lifecycle/__init__.py b/src/sparsetensors/quantization/lifecycle/__init__.py new file mode 100644 index 00000000..d90b28a9 --- /dev/null +++ b/src/sparsetensors/quantization/lifecycle/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa + +from .calibration import * +from .forward import * +from .frozen import * +from .initialize import * +from .status import * +from .initialize import * diff --git a/src/sparsetensors/quantization/lifecycle/calirbation.py b/src/sparsetensors/quantization/lifecycle/calirbation.py new file mode 100644 index 00000000..a4f4dfea --- /dev/null +++ b/src/sparsetensors/quantization/lifecycle/calirbation.py @@ -0,0 +1,44 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +from torch.nn import Module + +from sparseml.modifiers.quantization.lifecycle.status import QuantizationStatus + + +__all__ = [ + "set_module_for_calibration", +] + + +_LOGGER = logging.getLogger(__name__) + + +def set_module_for_calibration(module: Module): + if not getattr(module, "quantization_scheme", None): + # no quantization scheme nothing to do + return + status = getattr(module, "quantization_status", None) + if not status or status != QuantizationStatus.INITIALIZED: + raise _LOGGER.warning( + f"Attempting set module with status {status} to calibration mode. " + f"but status is not {QuantizationStatus.INITIALIZED} - you may " + "be calibrating an uninitialized module which may fail or attempting " + "to re-calibrate a frozen module" + ) + + module.quantization_status = QuantizationStatus.CALIBRATION \ No newline at end of file diff --git a/src/sparsetensors/quantization/lifecycle/forward.py b/src/sparsetensors/quantization/lifecycle/forward.py new file mode 100644 index 00000000..4247e7c7 --- /dev/null +++ b/src/sparsetensors/quantization/lifecycle/forward.py @@ -0,0 +1,155 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import wraps + +import torch +from torch.nn import Module + +from sparseml.modifiers.quantization.lifecycle.status import QuantizationStatus + +from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationScheme, QuantizationArgs + +__all__ = ["wrap_module_forward_quantized"] + + +def quantize( + x: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + q_max: torch.Tensor, +) -> torch.Tensor: + return torch.clamp( + torch.round( + x / scale + zero_point, + ), + 0, + q_max, + ) + + +def dequantize( + x_q: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, +) -> torch.Tensor: + return (x_q - zero_point) * scale + + +def fake_quantize( + x: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + args: QuantizationArgs, +) -> torch.Tensor: + max_q = torch.tensor(2**args.num_bits - 1) + columns = x.shape[1] + Q = torch.zeros_like(x) + # for i1 in range(0, columns, args.block_size): + # i2 = min(i1 + args.block_size, columns) + # count = i2 - i1 + + # W1 = x[:, i1:i2].clone() + # Q1 = torch.zeros_like(W1) + + # for i in range(count): + # w = W1[:, i] + # breakpoint() + # if args.group_size != -1: + # if (i1 + i) % args.group_size == 0: + # xmin, xmax = get_qparams( + # x[:, (i1 + i) : (i1 + i + args.group_size)], args.symmetric + # ) + # scale, zero = get_scale_zero_point( + # x[:, (i1 + i) : (i1 + i + args.group_size)], + # max_q, + # xmax, + # xmin, + # args.symmetric, + # args.group_size, + # ) + + # q = quantize(w.unsqueeze(1), scale, zero, max_q).flatten() + # Q1[:, i] = q + # Q[:, i1:i2] = Q1 + Q = quantize(x, scale, zero_point, max_q) + return dequantize(Q, scale, zero_point) + + +def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme): + # expects a module already initialized and injected with the parameters in + # initialize_module_for_quantization + forward_func_orig = module.forward.__func__ + + @wraps(forward_func_orig) # ensures docstring, names, etc are propagated + def wrapped_forward(self, *args, **kwargs): + input_ = args[0] + + if scheme.input_activations is not None: + # calibrate and (fake) quantize input activations when applicable + input_ = _maybe_calibrate_or_quantize( + module, input_, "input", scheme.input_activations + ) + + if scheme.weights is not None: + # calibrate and (fake) quantize weights when applicable + self.weight.data = _maybe_calibrate_or_quantize( + module, self.weight, "weight", scheme.weights + ) + + # perform wrapped forward call + output = forward_func_orig.__get__(module, module.__class__)( + input_, *args[1:], **kwargs + ) + + if scheme.output_activations is not None: + # calibrate and (fake) quantize output activations when applicable + output = _maybe_calibrate_or_quantize( + module, output, "output", scheme.output_activations + ) + + return output + + # bind wrapped forward to module class so reference to `self` is correct + bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__) + # set forward to wrapped forward + setattr(module, "forward", bound_wrapped_forward) + + +def _maybe_calibrate_or_quantize( + module: Module, 
value: Module, base_name: str, args: "QuantizationArgs" +) -> torch.Tensor: + # only run quantized for the included stages + if module.quantization_status not in { + QuantizationStatus.CALIBRATION, + QuantizationStatus.FROZEN, + }: + return value + + scale = getattr(module, f"{base_name}_scale") + # zero_point = getattr(module, f"{base_name}_zero_point").data + zero_point = getattr(module, f"{base_name}_zero_point") + + print(scale, zero_point) + + if module.quantization_status == QuantizationStatus.CALIBRATION: + # get observer and get new quant params from observation + observer = getattr(module, f"{base_name}_observer") + updated_scale, updated_zero_point = observer(value) + + # update scale and zero point + scale.data = updated_scale + zero_point.data = updated_zero_point + + return fake_quantize(value, scale, zero_point, args) \ No newline at end of file diff --git a/src/sparsetensors/quantization/lifecycle/frozen.py b/src/sparsetensors/quantization/lifecycle/frozen.py new file mode 100644 index 00000000..d480465b --- /dev/null +++ b/src/sparsetensors/quantization/lifecycle/frozen.py @@ -0,0 +1,37 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from torch.nn import Module + +from sparseml.modifiers.quantization.lifecycle.status import QuantizationStatus + + +__all__ = [ + "freeze_module_quantization", +] + + +def freeze_module_quantization(module: Module): + if not getattr(module, "quantization_scheme", None): + # no quantization scheme nothing to do + return + + # delete observers from module + for submodule_name, _ in module.named_modules(): + if "." not in submodule_name and submodule_name.endswith("_observer"): + # delete any observers that belong directly to this module + delattr(module, submodule_name) + + module.quantization_status = QuantizationStatus.FROZEN diff --git a/src/sparsetensors/quantization/lifecycle/initialize.py b/src/sparsetensors/quantization/lifecycle/initialize.py new file mode 100644 index 00000000..cfa4aa77 --- /dev/null +++ b/src/sparsetensors/quantization/lifecycle/initialize.py @@ -0,0 +1,77 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
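The quantize and dequantize helpers defined in forward.py above implement standard affine fake quantization: scale, shift by the zero point, round, clamp to the integer range, then map back to floating point. A minimal standalone sketch of that round trip, assuming an unsigned 8-bit range and per-tensor parameters (illustrative code, not part of this patch):

import torch

x = torch.randn(4, 4)
scale = (x.max() - x.min()) / 255           # asymmetric per-tensor scale
zero_point = torch.round(-x.min() / scale)  # integer level that represents 0.0

x_q = torch.clamp(torch.round(x / scale + zero_point), 0, 255)  # quantize
x_fq = (x_q - zero_point) * scale                               # dequantize
print((x - x_fq).abs().max())  # round-trip error is bounded by roughly scale / 2

The commented-out block inside fake_quantize sketches future group/block-wise support; the active path quantizes the whole tensor with a single scale and zero point.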
+ + +import logging + +import torch +from torch.nn import Module, Parameter + +from sparseml.modifiers.quantization.lifecycle.forward import ( + wrap_module_forward_quantized, +) +from sparseml.modifiers.quantization.lifecycle.status import QuantizationStatus +from sparseml.modifiers.quantization.utils.quantization_scheme import ( + QuantizationArgs, + QuantizationScheme, +) + + +__all__ = [ + "initialize_module_for_quantization", +] + + +_LOGGER = logging.getLogger(__name__) + + +def initialize_module_for_quantization(module: Module, scheme: QuantizationScheme): + if scheme.input_activations is not None: + + _initialize_scale_zero_point_observer( + module, "input", scheme.input_activations + ) + if scheme.weights is not None: + if hasattr(module, "weight"): + _initialize_scale_zero_point_observer(module, "weight", scheme.weights) + else: + _LOGGER.warning( + f"module type {type(module)} targeted for weight quantization but " + "has no attribute weight, skipping weight quantization " + f"for {type(module)}" + ) + if scheme.output_activations is not None: + _initialize_scale_zero_point_observer(module, "output", scheme.output_activations) + + module.quantization_scheme = scheme + module.quantization_status = QuantizationStatus.INITIALIZED + + # wrap forward call of module to perform quantized actions based on calltime status + wrap_module_forward_quantized(module, scheme) + + + +def _initialize_scale_zero_point_observer( + module: Module, base_name: str, quantization_args: QuantizationArgs +): + # initializes empty scale and zero point parameters for the module + init_scale = Parameter(torch.empty(0), requires_grad=False) + module.register_parameter(f"{base_name}_scale", init_scale) + + init_zero_point = Parameter(torch.empty(0, dtype=int), requires_grad=False) + module.register_parameter(f"{base_name}_zero_point", init_zero_point) + + # initialize observer module and attach as submodule + observer = quantization_args.get_observer() + module.register_module(f"{base_name}_observer", observer) diff --git a/src/sparsetensors/quantization/lifecycle/status.py b/src/sparsetensors/quantization/lifecycle/status.py new file mode 100644 index 00000000..3b6a441d --- /dev/null +++ b/src/sparsetensors/quantization/lifecycle/status.py @@ -0,0 +1,26 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +__all__ = [ + "QuantizationStatus", +] + + +class QuantizationStatus(Enum): + INITIALIZED = "INITIALIZED" + CALIBRATION = "CALIBRATION" + FROZEN = "FROZEN" diff --git a/src/sparsetensors/quantization/observers/__init__.py b/src/sparsetensors/quantization/observers/__init__.py new file mode 100644 index 00000000..1bec545d --- /dev/null +++ b/src/sparsetensors/quantization/observers/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa + +from .base import * +from .memoryless import * +from .min_max import * \ No newline at end of file diff --git a/src/sparsetensors/quantization/observers/base.py b/src/sparsetensors/quantization/observers/base.py new file mode 100644 index 00000000..44c8ec37 --- /dev/null +++ b/src/sparsetensors/quantization/observers/base.py @@ -0,0 +1,72 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +from torch import FloatTensor, IntTensor, Tensor +from torch.nn import Module + +from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationArgs +from sparsezoo.utils.registry import RegistryMixin + + +__all__ = ["Observer"] + + +class Observer(Module, RegistryMixin): + """ + Base Observer class to be subclassed for specific implementation. 
+ Subclasses should override `calculate_qparams` to return a scale, zero_point + pair + """ + + def __init__(self, + quantization_args: QuantizationArgs + ): + self.quantization_args: QuantizationArgs = quantization_args + super().__init__() + self._scale = None + self._zero_point = None + + def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: + """ + maps directly to get_qparams + :param observed: optional observed tensor to calculate quantization parameters + from + :return: tuple of scale and zero point based on last observed value + """ + return self.get_qparams(observed=observed) + + def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: + """ + :param observed: observed tensor to calculate quantization parameters for + :return: tuple of scale and zero point derived from the observed tensor + """ + raise NotImplementedError(f"{self.__class__} must implement calculate_qparams") + + def get_qparams( + self, observed: Optional[Tensor] = None + ) -> Tuple[FloatTensor, IntTensor]: + """ + Convenience function to wrap overwritten calculate_qparams + adds support to make observed tensor optional and support for tracking latest + calculated scale and zero point + :param observed: optional observed tensor to calculate quantization parameters + from + :return: tuple of scale and zero point based on last observed value + """ + if observed is not None: + # re-calcualte scale and zero point, update the stored value + self._scale, self._zero_point = self.calculate_qparams(observed) + return self._scale, self._zero_point \ No newline at end of file diff --git a/src/sparsetensors/quantization/observers/memoryless.py b/src/sparsetensors/quantization/observers/memoryless.py new file mode 100644 index 00000000..e69de29b diff --git a/src/sparsetensors/quantization/observers/min_max.py b/src/sparsetensors/quantization/observers/min_max.py new file mode 100644 index 00000000..c72eb1c0 --- /dev/null +++ b/src/sparsetensors/quantization/observers/min_max.py @@ -0,0 +1,80 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
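Because Observer pairs torch.nn.Module with sparsezoo's RegistryMixin, adding a new observer only requires registering a name and implementing calculate_qparams; forward and get_qparams are inherited from the base class. A hypothetical subclass for illustration (the "absmax" name and class are not part of this patch):

import torch
from sparsetensors.quantization.observers.base import Observer

@Observer.register("absmax")
class AbsMaxObserver(Observer):
    # toy observer: symmetric 8-bit scale from the absolute maximum of the tensor
    def calculate_qparams(self, observed: torch.Tensor):
        scale = 2 * observed.abs().max() / 255
        zero_point = torch.tensor(0, dtype=torch.int8)
        return scale, zero_point

The registered name is what a config such as QuantizationArgs(observer="absmax") would resolve through Observer.load_from_registry.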
+ +from typing import Tuple + +import torch +from torch import FloatTensor, IntTensor, Tensor + +from sparseml.modifiers.quantization.observers.base import Observer +from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationArgs + + +__all__ = ["MinMaxObserver"] + + +@Observer.register("minmax") +class MinMaxObserver(Observer): + """ + Implements a dynamic quantization observer that sets the scale and + zero point based on the latest observed value + """ + + def __init__(self, quantization_args: QuantizationArgs): + super().__init__(quantization_args=quantization_args) + + self.min_val = float("inf") + self.max_val = -float("inf") + self.counter = 0 + + def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: + """ + :param observed: observed tensor to calculate quantization parameters for + :return: tuple of scale and zero point derived from the observed tensor + """ + # TODO: Add support for full range of quantization Args, only supports 8bit + # per tensor + bit_range = 255 + min_val = torch.tensor([observed.min()]) + max_val = torch.tensor([observed.max()]) + + # running average + if self.counter > 0: + self.min_val = (self.min_val * self.counter + min_val) / (self.counter + 1) + self.max_val = (self.max_val * self.counter + max_val) / (self.counter + 1) + else: + self.min_val = min_val + self.max_val = max_val + + # ensure that the zeros are in the range + self.min_val = torch.min(self.min_val, torch.zeros_like(self.min_val)) + self.max_val = torch.max(self.max_val, torch.zeros_like(self.max_val)) + + self.counter += 1 + + if self.quantization_args.symmetric: + symmetric_range = 2 * max(self.min_val.abs(), self.max_val.abs()) + scale = symmetric_range / bit_range + zero_point = torch.tensor(0).to(torch.int8) + else: + # non-symmetric + observed_range = self.max_val - self.min_val + scale = observed_range / bit_range + + # scales from a 0 range should be set to 1 + scale[observed_range == 0] = 1 + + zero_point = (0 - self.min_val) / scale + + return scale, zero_point \ No newline at end of file diff --git a/src/sparsetensors/quantization/utils/quantization_scheme.py b/src/sparsetensors/quantization/utils/quantization_scheme.py new file mode 100644 index 00000000..976b534e --- /dev/null +++ b/src/sparsetensors/quantization/utils/quantization_scheme.py @@ -0,0 +1,391 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
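To make the MinMaxObserver arithmetic concrete: with the hard-coded 8-bit range of 255 and zero folded into the observed range, a first asymmetric observation spanning roughly [-0.5, 2.0] yields the following (illustrative numbers, not taken from the patch):

import torch

observed = torch.tensor([-0.5, 0.0, 1.0, 2.0])
min_val = torch.min(observed.min(), torch.tensor(0.0))  # -0.5, zero already in range
max_val = torch.max(observed.max(), torch.tensor(0.0))  #  2.0

scale = (max_val - min_val) / 255   # ~0.0098
zero_point = (0 - min_val) / scale  # ~51, the level that represents 0.0
print(scale.item(), zero_point.item())

On later calls the observer folds new statistics into a running average of min_val and max_val, so the quantization range tracks typical activations rather than any single batch.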
+ +""" +Schemas and types to support quantization +""" +from copy import deepcopy +from functools import partial +from typing import Any, Dict, Optional, Union + +import torch +from packaging import version +from pydantic import BaseModel, Field, validator +from torch.nn import Identity + + +try: + from torch import quantization as torch_quantization +except Exception: + torch_quantization = None + +from sparseml.modifiers.quantization.utils.fake_quant_wrapper import FakeQuantizeWrapper + + +__all__ = [ + "DictQuantizationArgs", + "DictQuantizationScheme", + "QuantizationArgs", + "QuantizationScheme", + "QuantizationSchemeLoadable", + "compute_range", + "get_observer", +] + + +_PARSED_TORCH_VERSION = version.parse(torch.__version__) +_TORCH_PRE_112 = _PARSED_TORCH_VERSION < version.parse("1.12.0") + + +""" +Type definition aliases for defining QuantizationArgs and QuantizationScheme +as dictionaries for YAML serialization +""" +DictQuantizationArgs = Dict[str, Union[int, bool, Dict[str, Any]]] +DictQuantizationScheme = Dict[str, DictQuantizationArgs] + +""" +Type definition for a type that is valid for loading a QuantizationScheme +using QuantizationScheme.load +""" +QuantizationSchemeLoadable = Union[ + "QuantizationScheme", + DictQuantizationScheme, + str, + None, +] + + +class QuantizationArgs(BaseModel): + """ + Class representing user facing arguments to define quantization Observers of + activations or weights in a network + """ + + num_bits: int = Field( + default=8, description="number of bits to target for quantization" + ) + symmetric: bool = Field( + default=False, + description="set True to use symmetric quantization. Default False", + ) + strategy: str = Field( + default="tensor", + description=( + "scope of the quantization to be applied. can be 'tensor' or 'channel'" + ), + ) + observer: str = Field( + default="minmax", + description=( + "The class to use to compute the quantization params - scale and zero-point'" + ), + + ) + + # kwargs: Dict[str, Any] = Field( + # default_factory=dict, + # description=( + # "optional dict of kwargs to be passed directly to torch quantization " + # "Observers constructor excluding quantization range or symmetry" + # ), + # ) + observer_kwargs: Dict[str, Any] = Field( + default_factory=dict, + description=( + "optional dict of kwargs to be passed directly to torch quantization " + "Observers constructor excluding quantization range or symmetry" + ), + ) + + @classmethod + def default_activation_args(cls): + """ + :return: default 8 bits asymmetric settings + """ + return cls(num_bits=8, symmetric=False) + + @classmethod + def default_weight_args(cls): + """ + :return: default 8 bits symmetric settings + """ + return cls(num_bits=8, symmetric=True) + + def get_observer(self) -> "torch.quantization.FakeQuantize": + """ + :return: torch quantization FakeQuantize built based on these QuantizationArgs + """ + from sparseml.modifiers.quantization.observers.base import Observer + return Observer.load_from_registry(self.observer, quantization_args=self) + + @validator("strategy") + def validate_strategy(cls, value): + valid_scopes = ["tensor", "channel"] + if value not in valid_scopes: + raise ValueError(f"`strategy` must be one of {valid_scopes}, got {value}") + return value + + +class QuantizationScheme(BaseModel): + """ + Class composed of QuantizationArgs to build QConfig and QuantWrapper objects for + quantizing models. 
Provides a simple user interface for defining how inputs, + weights, and outputs should be quantized + """ + + def __init__(self, *args, **kwargs): + # support for loading from yaml str + args = [arg if arg != "null" else None for arg in args] + for key, val in kwargs.items(): + if val == "null": + kwargs[key] = None + super().__init__(*args, **kwargs) + + input_activations: Optional[QuantizationArgs] = Field( + default_factory=QuantizationArgs.default_activation_args, + description=( + "target quantization setting for input activations. Set to None to " + "not quantize input activations. Default is 8 bits asymmetric" + ), + ) + weights: Optional[QuantizationArgs] = Field( + default_factory=QuantizationArgs.default_weight_args, + description=( + "target quantization setting for model weights. Set to None to " + "not quantize weights. Default is 8 bits symmetric" + ), + ) + output_activations: Optional[QuantizationArgs] = Field( + default=None, + description=( + "target quantization setting for output activations. Set to None to " + "not quantize output activations. Default is None" + ), + ) + target_hardware: Optional[str] = Field( + default=None, + description=( + "target deployment runtime/hardware name to be set by default " + "classmethods. Default is None" + ), + ) + + @classmethod + def load( + cls, + scheme: QuantizationSchemeLoadable, + default: Optional["QuantizationScheme"] = None, + ) -> "QuantizationScheme": + """ + :param scheme: QuantizationScheme, dict representation of scheme, + or string alias of a scheme to load. Valid strings: + ['default', 'deepsparse', 'tensorrt'] + :param default: default QuantizationScheme to override 'default' scheme + with + :return: constructed QuantizationScheme object from the given scheme; + if given a dict, returns QuantizationScheme.parse_obj(scheme), string + input will return the defualt QuantizationScheme if set to 'default'. + """ + if isinstance(scheme, cls): + return scheme + elif scheme is None or scheme == "default": + # if no default override, defaults to QuantizationScheme() + return deepcopy(default) or cls() + elif isinstance(scheme, str): + if scheme == "deepsparse": + return cls.deepsparse() + elif scheme == "tensorrt": + return cls.tensorrt() + raise ValueError( + f"Unrecognized QuantizationScheme string alias {scheme}. 
" + "Valid strings: ['default', 'deepsparse', 'tensorrt']" + ) + elif isinstance(scheme, dict): + # default to dict + scheme = {key: _parse_quantization_arg(arg) for key, arg in scheme.items()} + return cls.parse_obj(scheme) + else: + raise ValueError( + f"Unrecognized type {type(scheme)} for QuantizationScheme.load, " + "expected one of: [QuantizationScheme, Dict, str, None]" + ) + + @classmethod + def deepsparse(cls) -> "QuantizationScheme": + """ + :return: QuantizationScheme for deepsparse targeted deployments - + int8, symmetric weights, asymmetric inputs, no output quantization + """ + return cls( + input_activations=QuantizationArgs(num_bits=8, symmetric=False), + weights=QuantizationArgs(num_bits=8, symmetric=True), + output_activations=None, + target_hardware="deepsparse", + ) + + @classmethod + def tensorrt(cls) -> "QuantizationScheme": + """ + :return: QuantizationScheme for tensorrt targeted deployments - + compatibility with explict quantization as supported by TensorRT 8.2: + int8, symmetric for both weights and inputs, no output quantization + """ + return cls( + input_activations=QuantizationArgs(num_bits=8, symmetric=True), + weights=QuantizationArgs(num_bits=8, symmetric=True), + output_activations=None, + target_hardware="tensorrt", + ) + + def get_qconfig(self) -> "torch.quantization.QConfig": + """ + :return: QConfig for Modules (output activations used, + use QuantWrapper for inputs) + """ + qconfig = _get_qconfig(self.output_activations, self.weights) + # add reference to this quantization scheme for reference + qconfig.quantization_scheme = self + return qconfig + + def get_wrapper_qconfig(self) -> "torch.quantization.QConfig": + """ + :return: QConfig for QuantWrapper objects (input activations used) + """ + qconfig = _get_qconfig(self.input_activations, None) + # add reference to this quantization scheme for reference + qconfig.quantization_scheme = self + return qconfig + + def __str__(self) -> str: + """ + :return: YAML friendly string serialization + """ + dict_repr = self.dict() + dict_repr = { + key: val if val is not None else "null" for key, val in dict_repr.items() + } + return str(dict_repr) + + +def compute_range(dtype: torch.dtype, bits: int): + """ + compute quantization limits depending on data type and number of bits + + :param dtype: data type. + :param bits: number of bits. 
+ :return: minimum limit, maximum limit, whether the range is customized + """ + bits = bits if bits else 8 + is_custom = bits != 8 + if dtype == torch.qint8: + quant_min = -(2 ** (bits - 1)) + quant_max = (2 ** (bits - 1)) - 1 + elif dtype == torch.quint8: + quant_min = 0 + quant_max = (2**bits) - 1 + + return quant_min, quant_max, is_custom + + +# def get_observer( +# symmetric: bool, +# strategy: str, +# dtype: torch.dtype, +# bits: int, +# reduce_range: bool, +# qconfig_kwargs: Dict[str, Any], +# ): +# quant_min, quant_max, is_custom_qrange = compute_range(dtype, bits) + +# if strategy == "channel": +# qscheme = torch.per_channel_symmetric if symmetric else torch.per_channel_affine +# observer_cls = torch_quantization.MovingAveragePerChannelMinMaxObserver +# observer_kwargs = dict( +# ch_axis=0, +# dtype=dtype, +# qscheme=qscheme, +# reduce_range=reduce_range, +# ) +# else: # default to tensor strategy +# qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine +# observer_cls = torch_quantization.MovingAverageMinMaxObserver +# observer_kwargs = dict( +# dtype=dtype, +# qscheme=qscheme, +# reduce_range=reduce_range, +# ) +# """ +# in torch 1.9.1, quant_min and quant_max are not passed to observer: +# https://github.com/pytorch/pytorch/blob/v1.9.1/torch/quantization/fake_quantize.py#L109 +# however in 1.12.0, this is fixed so both are passed to observer: +# https://github.com/pytorch/pytorch/blob/v1.12.1/torch/ao/quantization/fake_quantize.py#L132 + +# Passing quant_min/quant_max to observer means the observer will have +# `self.has_customized_qrange == True` in both 1.9.1 and 1.12.0. + +# For whatever reason, both versions calculate zero point for +# quint8 differently **if there is a customized_qrange** +# 1. customized qrange has zero point of 127 +# 2. non-customized has zero point of 128. +# source: +# https://github.com/pytorch/pytorch/blob/v1.12.1/torch/ao/quantization/observer.py#L293 + +# **we want to ensure that the zero point is 128** +# see https://github.com/neuralmagic/sparseml/pull/604 +# """ +# if is_custom_qrange: +# # for both versions we need to include the custom min/max values in kwargs +# observer_kwargs["quant_min"] = quant_min +# observer_kwargs["quant_max"] = quant_max +# if _TORCH_PRE_112: +# # pre 1.12, the observer doesn't get passed the quant_min/quant_max values, +# # so we patch them in to the constructor of the observer +# observer_cls = partial( +# observer_cls, quant_min=quant_min, quant_max=quant_max +# ) +# else: +# # if using a non custom qrange, we can rely on default values used by +# # the observers +# if _TORCH_PRE_112: +# # pre 1.12, the observer doesn't get passed the quant_min/quant_max values, +# # so we are safe to pass these to FakeQuantize +# observer_kwargs["quant_min"] = quant_min +# observer_kwargs["quant_max"] = quant_max +# else: +# # post 1.12 we cannot pass them to the observer since that will set +# # has_customized_qrange. instead we rely on the default values +# # being equal to the `quant_min` and `quant_max` here. 
+# pass + +# observer_kwargs["observer"] = observer_cls +# observer_kwargs.update(qconfig_kwargs or {}) +# observer = FakeQuantizeWrapper.with_args(**observer_kwargs) + +# return observer + + +def _get_qconfig( + activation_args: Optional[QuantizationArgs], weight_args: Optional[QuantizationArgs] +) -> "torch.quantization.QConfig": + return torch_quantization.QConfig( + activation=activation_args.get_observer() if activation_args else Identity, + weight=weight_args.get_observer() if weight_args else Identity, + ) + + +def _parse_quantization_arg(arg: Any): + if arg == "None": + return None + return arg \ No newline at end of file From 560ef13a372765ffac5fc1d9aa21049289d85b21 Mon Sep 17 00:00:00 2001 From: george ohashi Date: Fri, 12 Apr 2024 16:46:14 +0000 Subject: [PATCH 2/6] add memoryless --- .../quantization/observers/memoryless.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/src/sparsetensors/quantization/observers/memoryless.py b/src/sparsetensors/quantization/observers/memoryless.py index e69de29b..5f74448b 100644 --- a/src/sparsetensors/quantization/observers/memoryless.py +++ b/src/sparsetensors/quantization/observers/memoryless.py @@ -0,0 +1,63 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
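Stepping back briefly to compute_range in quantization_scheme.py above (the helper the commented-out get_observer relies on): it maps a torch quantized dtype and bit width to integer limits and flags any non-8-bit width as a customized range. A quick check of the expected outputs, assuming the module as added in the first patch (it is removed later in this series):

import torch
from sparsetensors.quantization.utils.quantization_scheme import compute_range

print(compute_range(torch.quint8, 8))  # (0, 255, False)
print(compute_range(torch.qint8, 8))   # (-128, 127, False)
print(compute_range(torch.qint8, 4))   # (-8, 7, True), custom 4-bit range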
+ +from typing import Tuple + +import torch +from torch import FloatTensor, IntTensor, Tensor + +from sparseml.modifiers.quantization.observers.base import Observer +# from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationArgs + + +__all__ = ["MemorylessObserver"] + + +@Observer.register("memoryless") +class MemorylessObserver(Observer): + """ + Implements a dynamic quantization observer that sets the scale and + zero point based on the latest observed value + """ + + def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: + """ + :param observed: observed tensor to calculate quantization parameters for + :return: tuple of scale and zero point derived from the observed tensor + """ + # TODO: Add support for full range of quantization Args, only supports 8bit + # per tensor + bit_range = 255 + min_val = observed.min() + max_val = observed.max() + + # ensure zero is in the range + min_val = torch.min(min_val, torch.zeros_like(min_val)) + max_val = torch.max(max_val, torch.zeros_like(max_val)) + + if self.quantization_args.symmetric: + symmetric_range = 2 * max(min_val.abs(), max_val.abs()) + scale = symmetric_range / bit_range + zero_point = torch.tensor(0).to(torch.int8) + else: + # non-symmetric + observed_range = max_val - min_val + scale = observed_range / bit_range + + # scales from a 0 range should be set to 1 + scale[observed_range == 0] = 1 + + zero_point = (0 - min_val) / scale + + return scale, zero_point \ No newline at end of file From 0804be30be70a5578e8dc3d5a518c3fd90cac1c1 Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Fri, 12 Apr 2024 17:13:26 +0000 Subject: [PATCH 3/6] run bin.quant --- bin/quant.py | 27 +- setup.py | 2 +- .../quantization/lifecycle/__init__.py | 1 - .../{calirbation.py => calibration.py} | 5 +- .../quantization/lifecycle/forward.py | 21 +- .../quantization/lifecycle/frozen.py | 3 +- .../quantization/lifecycle/initialize.py | 27 +- .../quantization/observers/__init__.py | 2 +- .../quantization/observers/base.py | 12 +- .../quantization/observers/memoryless.py | 7 +- .../quantization/observers/min_max.py | 7 +- src/sparsetensors/quantization/quant_args.py | 25 +- .../quantization/utils/quantization_scheme.py | 391 ------------------ 13 files changed, 82 insertions(+), 448 deletions(-) rename src/sparsetensors/quantization/lifecycle/{calirbation.py => calibration.py} (90%) delete mode 100644 src/sparsetensors/quantization/utils/quantization_scheme.py diff --git a/bin/quant.py b/bin/quant.py index 3d33191f..94bfc448 100644 --- a/bin/quant.py +++ b/bin/quant.py @@ -1,17 +1,18 @@ import torch from torch.nn import Linear -# from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationScheme, QuantizationArgs + from sparsetensors.quantization.quant_args import QuantizationArgs from sparsetensors.quantization.quant_scheme import QuantizationScheme -from sparseml.modifiers.quantization.lifecycle.initialize import initialize_module_for_quantization -from sparseml.modifiers.quantization.lifecycle.calibration import set_module_for_calibration -from sparseml.modifiers.quantization.lifecycle.frozen import freeze_module_quantization +from sparsetensors.quantization.lifecycle.initialize import initialize_module_for_quantization +from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration +from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization num_bits = 8 scheme = QuantizationScheme( input_acivations=QuantizationArgs(num_bits=num_bits, 
symmetric=False), weights=QuantizationArgs(num_bits=num_bits, symmetric=True), output_activations=None, + targets = ["*"], ) layer = Linear(4, 4) @@ -31,25 +32,29 @@ layer(torch.randn(4,4)) print(dict(layer.named_parameters())) # scale and zero point should have updated values print(2) -for _ in range(10): +print("calib layers ") +for i in range(10): + print("iter", i) layer(torch.randn(4,4)) print(dict(layer.named_parameters())) # scale and zero point should have updated values again since we did another pass print(3) -breakpoint() +# breakpoint() freeze_module_quantization(layer) -for _ in range(10): +print("freeze layers ") +for i in range(10): # do more forward passes but show args are frozen - layer(torch.random.randn(4,4)) + print("iter", i) + layer(torch.randn(4,4)) print(dict(layer.named_parameters())) # scale and zero point should not be updated now -# missing +# # missing -# correctness -# quantizing an entire model +# # correctness +# # quantizing an entire model diff --git a/setup.py b/setup.py index 12ef67a3..de506c99 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ def _setup_install_requires() -> List: return ["torch>=1.7.0", "transformers<=4.40", "pydantic<2.7"] def _setup_extras() -> Dict: - return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0"]} + return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0", "sparsezoo"]} setup( name="sparsetensors", diff --git a/src/sparsetensors/quantization/lifecycle/__init__.py b/src/sparsetensors/quantization/lifecycle/__init__.py index d90b28a9..52b86440 100644 --- a/src/sparsetensors/quantization/lifecycle/__init__.py +++ b/src/sparsetensors/quantization/lifecycle/__init__.py @@ -19,4 +19,3 @@ from .frozen import * from .initialize import * from .status import * -from .initialize import * diff --git a/src/sparsetensors/quantization/lifecycle/calirbation.py b/src/sparsetensors/quantization/lifecycle/calibration.py similarity index 90% rename from src/sparsetensors/quantization/lifecycle/calirbation.py rename to src/sparsetensors/quantization/lifecycle/calibration.py index a4f4dfea..986b062a 100644 --- a/src/sparsetensors/quantization/lifecycle/calirbation.py +++ b/src/sparsetensors/quantization/lifecycle/calibration.py @@ -15,10 +15,9 @@ import logging +from sparsetensors.quantization.lifecycle.status import QuantizationStatus from torch.nn import Module -from sparseml.modifiers.quantization.lifecycle.status import QuantizationStatus - __all__ = [ "set_module_for_calibration", @@ -41,4 +40,4 @@ def set_module_for_calibration(module: Module): "to re-calibrate a frozen module" ) - module.quantization_status = QuantizationStatus.CALIBRATION \ No newline at end of file + module.quantization_status = QuantizationStatus.CALIBRATION diff --git a/src/sparsetensors/quantization/lifecycle/forward.py b/src/sparsetensors/quantization/lifecycle/forward.py index 4247e7c7..cbb27dea 100644 --- a/src/sparsetensors/quantization/lifecycle/forward.py +++ b/src/sparsetensors/quantization/lifecycle/forward.py @@ -15,11 +15,16 @@ from functools import wraps import torch +from sparsetensors.quantization.lifecycle.status import QuantizationStatus + +# from sparsetensors.quantization.utils.quantization_scheme import ( +# QuantizationArgs, +# QuantizationScheme, +# ) +from sparsetensors.quantization.quant_args import QuantizationArgs +from sparsetensors.quantization.quant_scheme import QuantizationScheme from torch.nn import Module -from 
sparseml.modifiers.quantization.lifecycle.status import QuantizationStatus - -from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationScheme, QuantizationArgs __all__ = ["wrap_module_forward_quantized"] @@ -34,8 +39,8 @@ def quantize( torch.round( x / scale + zero_point, ), - 0, - q_max, + 0, + q_max, ) @@ -83,7 +88,7 @@ def fake_quantize( # q = quantize(w.unsqueeze(1), scale, zero, max_q).flatten() # Q1[:, i] = q # Q[:, i1:i2] = Q1 - Q = quantize(x, scale, zero_point, max_q) + Q = quantize(x, scale, zero_point, max_q) return dequantize(Q, scale, zero_point) @@ -138,7 +143,7 @@ def _maybe_calibrate_or_quantize( return value scale = getattr(module, f"{base_name}_scale") - # zero_point = getattr(module, f"{base_name}_zero_point").data + # zero_point = getattr(module, f"{base_name}_zero_point").data zero_point = getattr(module, f"{base_name}_zero_point") print(scale, zero_point) @@ -152,4 +157,4 @@ def _maybe_calibrate_or_quantize( scale.data = updated_scale zero_point.data = updated_zero_point - return fake_quantize(value, scale, zero_point, args) \ No newline at end of file + return fake_quantize(value, scale, zero_point, args) diff --git a/src/sparsetensors/quantization/lifecycle/frozen.py b/src/sparsetensors/quantization/lifecycle/frozen.py index d480465b..6b92eee7 100644 --- a/src/sparsetensors/quantization/lifecycle/frozen.py +++ b/src/sparsetensors/quantization/lifecycle/frozen.py @@ -13,10 +13,9 @@ # limitations under the License. +from sparsetensors.quantization.lifecycle.status import QuantizationStatus from torch.nn import Module -from sparseml.modifiers.quantization.lifecycle.status import QuantizationStatus - __all__ = [ "freeze_module_quantization", diff --git a/src/sparsetensors/quantization/lifecycle/initialize.py b/src/sparsetensors/quantization/lifecycle/initialize.py index cfa4aa77..6d23f4cc 100644 --- a/src/sparsetensors/quantization/lifecycle/initialize.py +++ b/src/sparsetensors/quantization/lifecycle/initialize.py @@ -16,17 +16,17 @@ import logging import torch +from sparsetensors.quantization.lifecycle.forward import wrap_module_forward_quantized +from sparsetensors.quantization.lifecycle.status import QuantizationStatus + +# from sparsetensors.quantization.utils.quantization_scheme import ( +# QuantizationArgs, +# QuantizationScheme, +# ) +from sparsetensors.quantization.quant_args import QuantizationArgs +from sparsetensors.quantization.quant_scheme import QuantizationScheme from torch.nn import Module, Parameter -from sparseml.modifiers.quantization.lifecycle.forward import ( - wrap_module_forward_quantized, -) -from sparseml.modifiers.quantization.lifecycle.status import QuantizationStatus -from sparseml.modifiers.quantization.utils.quantization_scheme import ( - QuantizationArgs, - QuantizationScheme, -) - __all__ = [ "initialize_module_for_quantization", @@ -39,9 +39,7 @@ def initialize_module_for_quantization(module: Module, scheme: QuantizationScheme): if scheme.input_activations is not None: - _initialize_scale_zero_point_observer( - module, "input", scheme.input_activations - ) + _initialize_scale_zero_point_observer(module, "input", scheme.input_activations) if scheme.weights is not None: if hasattr(module, "weight"): _initialize_scale_zero_point_observer(module, "weight", scheme.weights) @@ -52,7 +50,9 @@ def initialize_module_for_quantization(module: Module, scheme: QuantizationSchem f"for {type(module)}" ) if scheme.output_activations is not None: - _initialize_scale_zero_point_observer(module, "output", 
scheme.output_activations) + _initialize_scale_zero_point_observer( + module, "output", scheme.output_activations + ) module.quantization_scheme = scheme module.quantization_status = QuantizationStatus.INITIALIZED @@ -61,7 +61,6 @@ def initialize_module_for_quantization(module: Module, scheme: QuantizationSchem wrap_module_forward_quantized(module, scheme) - def _initialize_scale_zero_point_observer( module: Module, base_name: str, quantization_args: QuantizationArgs ): diff --git a/src/sparsetensors/quantization/observers/__init__.py b/src/sparsetensors/quantization/observers/__init__.py index 1bec545d..d0362b8f 100644 --- a/src/sparsetensors/quantization/observers/__init__.py +++ b/src/sparsetensors/quantization/observers/__init__.py @@ -16,4 +16,4 @@ from .base import * from .memoryless import * -from .min_max import * \ No newline at end of file +from .min_max import * diff --git a/src/sparsetensors/quantization/observers/base.py b/src/sparsetensors/quantization/observers/base.py index 44c8ec37..00cd7561 100644 --- a/src/sparsetensors/quantization/observers/base.py +++ b/src/sparsetensors/quantization/observers/base.py @@ -14,12 +14,12 @@ from typing import Optional, Tuple +# from sparsetensors.quantization.utils.quantization_scheme import QuantizationArgs +from sparsetensors.quantization.quant_args import QuantizationArgs +from sparsezoo.utils.registry import RegistryMixin from torch import FloatTensor, IntTensor, Tensor from torch.nn import Module -from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationArgs -from sparsezoo.utils.registry import RegistryMixin - __all__ = ["Observer"] @@ -31,9 +31,7 @@ class Observer(Module, RegistryMixin): pair """ - def __init__(self, - quantization_args: QuantizationArgs - ): + def __init__(self, quantization_args: QuantizationArgs): self.quantization_args: QuantizationArgs = quantization_args super().__init__() self._scale = None @@ -69,4 +67,4 @@ def get_qparams( if observed is not None: # re-calcualte scale and zero point, update the stored value self._scale, self._zero_point = self.calculate_qparams(observed) - return self._scale, self._zero_point \ No newline at end of file + return self._scale, self._zero_point diff --git a/src/sparsetensors/quantization/observers/memoryless.py b/src/sparsetensors/quantization/observers/memoryless.py index 5f74448b..faabbb5a 100644 --- a/src/sparsetensors/quantization/observers/memoryless.py +++ b/src/sparsetensors/quantization/observers/memoryless.py @@ -15,10 +15,11 @@ from typing import Tuple import torch +from sparsetensors.quantization.observers.base import Observer from torch import FloatTensor, IntTensor, Tensor -from sparseml.modifiers.quantization.observers.base import Observer -# from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationArgs + +# from sparsetensors.quantization.utils.quantization_scheme import QuantizationArgs __all__ = ["MemorylessObserver"] @@ -60,4 +61,4 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: zero_point = (0 - min_val) / scale - return scale, zero_point \ No newline at end of file + return scale, zero_point diff --git a/src/sparsetensors/quantization/observers/min_max.py b/src/sparsetensors/quantization/observers/min_max.py index c72eb1c0..40cde72c 100644 --- a/src/sparsetensors/quantization/observers/min_max.py +++ b/src/sparsetensors/quantization/observers/min_max.py @@ -15,11 +15,10 @@ from typing import Tuple import torch +from sparsetensors.quantization.observers.base import 
Observer +from sparsetensors.quantization.quant_args import QuantizationArgs from torch import FloatTensor, IntTensor, Tensor -from sparseml.modifiers.quantization.observers.base import Observer -from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationArgs - __all__ = ["MinMaxObserver"] @@ -77,4 +76,4 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: zero_point = (0 - self.min_val) / scale - return scale, zero_point \ No newline at end of file + return scale, zero_point diff --git a/src/sparsetensors/quantization/quant_args.py b/src/sparsetensors/quantization/quant_args.py index 89a2e3df..fb9e9b01 100644 --- a/src/sparsetensors/quantization/quant_args.py +++ b/src/sparsetensors/quantization/quant_args.py @@ -13,9 +13,9 @@ # limitations under the License. from enum import Enum -from typing import Optional +from typing import Any, Dict, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field __all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"] @@ -61,3 +61,24 @@ class QuantizationArgs(BaseModel): strategy: QuantizationStrategy = QuantizationStrategy.TENSOR group_size: Optional[int] = None block_structure: Optional[str] = None + observer: str = Field( + default="minmax", + description=( + "The class to use to compute the quantization params - scale and zero-point'" + ), + ) + observer_kwargs: Dict[str, Any] = Field( + default_factory=dict, + description=( + "optional dict of kwargs to be passed directly to torch quantization " + "Observers constructor excluding quantization range or symmetry" + ), + ) + + def get_observer(self): + """ + :return: torch quantization FakeQuantize built based on these QuantizationArgs + """ + from sparsetensors.quantization.observers.base import Observer + + return Observer.load_from_registry(self.observer, quantization_args=self) diff --git a/src/sparsetensors/quantization/utils/quantization_scheme.py b/src/sparsetensors/quantization/utils/quantization_scheme.py deleted file mode 100644 index 976b534e..00000000 --- a/src/sparsetensors/quantization/utils/quantization_scheme.py +++ /dev/null @@ -1,391 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
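The observer and observer_kwargs fields added to QuantizationArgs connect the pydantic config to the observer registry: get_observer resolves the registered class by name and hands it the args. A short usage sketch under the APIs in this patch series (shapes and values are illustrative):

import torch
from sparsetensors.quantization.quant_args import QuantizationArgs

args = QuantizationArgs(num_bits=8, symmetric=False, observer="minmax")
observer = args.get_observer()  # MinMaxObserver, via Observer.load_from_registry
scale, zero_point = observer(torch.randn(16, 16))  # forward() maps to get_qparams
print(scale, zero_point)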
- -""" -Schemas and types to support quantization -""" -from copy import deepcopy -from functools import partial -from typing import Any, Dict, Optional, Union - -import torch -from packaging import version -from pydantic import BaseModel, Field, validator -from torch.nn import Identity - - -try: - from torch import quantization as torch_quantization -except Exception: - torch_quantization = None - -from sparseml.modifiers.quantization.utils.fake_quant_wrapper import FakeQuantizeWrapper - - -__all__ = [ - "DictQuantizationArgs", - "DictQuantizationScheme", - "QuantizationArgs", - "QuantizationScheme", - "QuantizationSchemeLoadable", - "compute_range", - "get_observer", -] - - -_PARSED_TORCH_VERSION = version.parse(torch.__version__) -_TORCH_PRE_112 = _PARSED_TORCH_VERSION < version.parse("1.12.0") - - -""" -Type definition aliases for defining QuantizationArgs and QuantizationScheme -as dictionaries for YAML serialization -""" -DictQuantizationArgs = Dict[str, Union[int, bool, Dict[str, Any]]] -DictQuantizationScheme = Dict[str, DictQuantizationArgs] - -""" -Type definition for a type that is valid for loading a QuantizationScheme -using QuantizationScheme.load -""" -QuantizationSchemeLoadable = Union[ - "QuantizationScheme", - DictQuantizationScheme, - str, - None, -] - - -class QuantizationArgs(BaseModel): - """ - Class representing user facing arguments to define quantization Observers of - activations or weights in a network - """ - - num_bits: int = Field( - default=8, description="number of bits to target for quantization" - ) - symmetric: bool = Field( - default=False, - description="set True to use symmetric quantization. Default False", - ) - strategy: str = Field( - default="tensor", - description=( - "scope of the quantization to be applied. can be 'tensor' or 'channel'" - ), - ) - observer: str = Field( - default="minmax", - description=( - "The class to use to compute the quantization params - scale and zero-point'" - ), - - ) - - # kwargs: Dict[str, Any] = Field( - # default_factory=dict, - # description=( - # "optional dict of kwargs to be passed directly to torch quantization " - # "Observers constructor excluding quantization range or symmetry" - # ), - # ) - observer_kwargs: Dict[str, Any] = Field( - default_factory=dict, - description=( - "optional dict of kwargs to be passed directly to torch quantization " - "Observers constructor excluding quantization range or symmetry" - ), - ) - - @classmethod - def default_activation_args(cls): - """ - :return: default 8 bits asymmetric settings - """ - return cls(num_bits=8, symmetric=False) - - @classmethod - def default_weight_args(cls): - """ - :return: default 8 bits symmetric settings - """ - return cls(num_bits=8, symmetric=True) - - def get_observer(self) -> "torch.quantization.FakeQuantize": - """ - :return: torch quantization FakeQuantize built based on these QuantizationArgs - """ - from sparseml.modifiers.quantization.observers.base import Observer - return Observer.load_from_registry(self.observer, quantization_args=self) - - @validator("strategy") - def validate_strategy(cls, value): - valid_scopes = ["tensor", "channel"] - if value not in valid_scopes: - raise ValueError(f"`strategy` must be one of {valid_scopes}, got {value}") - return value - - -class QuantizationScheme(BaseModel): - """ - Class composed of QuantizationArgs to build QConfig and QuantWrapper objects for - quantizing models. 
Provides a simple user interface for defining how inputs, - weights, and outputs should be quantized - """ - - def __init__(self, *args, **kwargs): - # support for loading from yaml str - args = [arg if arg != "null" else None for arg in args] - for key, val in kwargs.items(): - if val == "null": - kwargs[key] = None - super().__init__(*args, **kwargs) - - input_activations: Optional[QuantizationArgs] = Field( - default_factory=QuantizationArgs.default_activation_args, - description=( - "target quantization setting for input activations. Set to None to " - "not quantize input activations. Default is 8 bits asymmetric" - ), - ) - weights: Optional[QuantizationArgs] = Field( - default_factory=QuantizationArgs.default_weight_args, - description=( - "target quantization setting for model weights. Set to None to " - "not quantize weights. Default is 8 bits symmetric" - ), - ) - output_activations: Optional[QuantizationArgs] = Field( - default=None, - description=( - "target quantization setting for output activations. Set to None to " - "not quantize output activations. Default is None" - ), - ) - target_hardware: Optional[str] = Field( - default=None, - description=( - "target deployment runtime/hardware name to be set by default " - "classmethods. Default is None" - ), - ) - - @classmethod - def load( - cls, - scheme: QuantizationSchemeLoadable, - default: Optional["QuantizationScheme"] = None, - ) -> "QuantizationScheme": - """ - :param scheme: QuantizationScheme, dict representation of scheme, - or string alias of a scheme to load. Valid strings: - ['default', 'deepsparse', 'tensorrt'] - :param default: default QuantizationScheme to override 'default' scheme - with - :return: constructed QuantizationScheme object from the given scheme; - if given a dict, returns QuantizationScheme.parse_obj(scheme), string - input will return the defualt QuantizationScheme if set to 'default'. - """ - if isinstance(scheme, cls): - return scheme - elif scheme is None or scheme == "default": - # if no default override, defaults to QuantizationScheme() - return deepcopy(default) or cls() - elif isinstance(scheme, str): - if scheme == "deepsparse": - return cls.deepsparse() - elif scheme == "tensorrt": - return cls.tensorrt() - raise ValueError( - f"Unrecognized QuantizationScheme string alias {scheme}. 
" - "Valid strings: ['default', 'deepsparse', 'tensorrt']" - ) - elif isinstance(scheme, dict): - # default to dict - scheme = {key: _parse_quantization_arg(arg) for key, arg in scheme.items()} - return cls.parse_obj(scheme) - else: - raise ValueError( - f"Unrecognized type {type(scheme)} for QuantizationScheme.load, " - "expected one of: [QuantizationScheme, Dict, str, None]" - ) - - @classmethod - def deepsparse(cls) -> "QuantizationScheme": - """ - :return: QuantizationScheme for deepsparse targeted deployments - - int8, symmetric weights, asymmetric inputs, no output quantization - """ - return cls( - input_activations=QuantizationArgs(num_bits=8, symmetric=False), - weights=QuantizationArgs(num_bits=8, symmetric=True), - output_activations=None, - target_hardware="deepsparse", - ) - - @classmethod - def tensorrt(cls) -> "QuantizationScheme": - """ - :return: QuantizationScheme for tensorrt targeted deployments - - compatibility with explict quantization as supported by TensorRT 8.2: - int8, symmetric for both weights and inputs, no output quantization - """ - return cls( - input_activations=QuantizationArgs(num_bits=8, symmetric=True), - weights=QuantizationArgs(num_bits=8, symmetric=True), - output_activations=None, - target_hardware="tensorrt", - ) - - def get_qconfig(self) -> "torch.quantization.QConfig": - """ - :return: QConfig for Modules (output activations used, - use QuantWrapper for inputs) - """ - qconfig = _get_qconfig(self.output_activations, self.weights) - # add reference to this quantization scheme for reference - qconfig.quantization_scheme = self - return qconfig - - def get_wrapper_qconfig(self) -> "torch.quantization.QConfig": - """ - :return: QConfig for QuantWrapper objects (input activations used) - """ - qconfig = _get_qconfig(self.input_activations, None) - # add reference to this quantization scheme for reference - qconfig.quantization_scheme = self - return qconfig - - def __str__(self) -> str: - """ - :return: YAML friendly string serialization - """ - dict_repr = self.dict() - dict_repr = { - key: val if val is not None else "null" for key, val in dict_repr.items() - } - return str(dict_repr) - - -def compute_range(dtype: torch.dtype, bits: int): - """ - compute quantization limits depending on data type and number of bits - - :param dtype: data type. - :param bits: number of bits. 
- :return: minimum limit, maximum limit, whether the range is customized - """ - bits = bits if bits else 8 - is_custom = bits != 8 - if dtype == torch.qint8: - quant_min = -(2 ** (bits - 1)) - quant_max = (2 ** (bits - 1)) - 1 - elif dtype == torch.quint8: - quant_min = 0 - quant_max = (2**bits) - 1 - - return quant_min, quant_max, is_custom - - -# def get_observer( -# symmetric: bool, -# strategy: str, -# dtype: torch.dtype, -# bits: int, -# reduce_range: bool, -# qconfig_kwargs: Dict[str, Any], -# ): -# quant_min, quant_max, is_custom_qrange = compute_range(dtype, bits) - -# if strategy == "channel": -# qscheme = torch.per_channel_symmetric if symmetric else torch.per_channel_affine -# observer_cls = torch_quantization.MovingAveragePerChannelMinMaxObserver -# observer_kwargs = dict( -# ch_axis=0, -# dtype=dtype, -# qscheme=qscheme, -# reduce_range=reduce_range, -# ) -# else: # default to tensor strategy -# qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine -# observer_cls = torch_quantization.MovingAverageMinMaxObserver -# observer_kwargs = dict( -# dtype=dtype, -# qscheme=qscheme, -# reduce_range=reduce_range, -# ) -# """ -# in torch 1.9.1, quant_min and quant_max are not passed to observer: -# https://github.com/pytorch/pytorch/blob/v1.9.1/torch/quantization/fake_quantize.py#L109 -# however in 1.12.0, this is fixed so both are passed to observer: -# https://github.com/pytorch/pytorch/blob/v1.12.1/torch/ao/quantization/fake_quantize.py#L132 - -# Passing quant_min/quant_max to observer means the observer will have -# `self.has_customized_qrange == True` in both 1.9.1 and 1.12.0. - -# For whatever reason, both versions calculate zero point for -# quint8 differently **if there is a customized_qrange** -# 1. customized qrange has zero point of 127 -# 2. non-customized has zero point of 128. -# source: -# https://github.com/pytorch/pytorch/blob/v1.12.1/torch/ao/quantization/observer.py#L293 - -# **we want to ensure that the zero point is 128** -# see https://github.com/neuralmagic/sparseml/pull/604 -# """ -# if is_custom_qrange: -# # for both versions we need to include the custom min/max values in kwargs -# observer_kwargs["quant_min"] = quant_min -# observer_kwargs["quant_max"] = quant_max -# if _TORCH_PRE_112: -# # pre 1.12, the observer doesn't get passed the quant_min/quant_max values, -# # so we patch them in to the constructor of the observer -# observer_cls = partial( -# observer_cls, quant_min=quant_min, quant_max=quant_max -# ) -# else: -# # if using a non custom qrange, we can rely on default values used by -# # the observers -# if _TORCH_PRE_112: -# # pre 1.12, the observer doesn't get passed the quant_min/quant_max values, -# # so we are safe to pass these to FakeQuantize -# observer_kwargs["quant_min"] = quant_min -# observer_kwargs["quant_max"] = quant_max -# else: -# # post 1.12 we cannot pass them to the observer since that will set -# # has_customized_qrange. instead we rely on the default values -# # being equal to the `quant_min` and `quant_max` here. 
-# pass - -# observer_kwargs["observer"] = observer_cls -# observer_kwargs.update(qconfig_kwargs or {}) -# observer = FakeQuantizeWrapper.with_args(**observer_kwargs) - -# return observer - - -def _get_qconfig( - activation_args: Optional[QuantizationArgs], weight_args: Optional[QuantizationArgs] -) -> "torch.quantization.QConfig": - return torch_quantization.QConfig( - activation=activation_args.get_observer() if activation_args else Identity, - weight=weight_args.get_observer() if weight_args else Identity, - ) - - -def _parse_quantization_arg(arg: Any): - if arg == "None": - return None - return arg \ No newline at end of file From 971d1408fbf230014651c74524403568cb9d3b62 Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Fri, 12 Apr 2024 18:03:36 +0000 Subject: [PATCH 4/6] before tests, correctness verified --- bin/quant.py | 60 ------------------- .../quantization/lifecycle/forward.py | 33 ---------- .../quantization/lifecycle/initialize.py | 5 -- .../quantization/observers/base.py | 1 - .../quantization/observers/memoryless.py | 3 - src/sparsetensors/quantization/quant_args.py | 3 +- 6 files changed, 2 insertions(+), 103 deletions(-) delete mode 100644 bin/quant.py diff --git a/bin/quant.py b/bin/quant.py deleted file mode 100644 index 94bfc448..00000000 --- a/bin/quant.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch -from torch.nn import Linear - -from sparsetensors.quantization.quant_args import QuantizationArgs -from sparsetensors.quantization.quant_scheme import QuantizationScheme -from sparsetensors.quantization.lifecycle.initialize import initialize_module_for_quantization -from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration -from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization -num_bits = 8 - -scheme = QuantizationScheme( - input_acivations=QuantizationArgs(num_bits=num_bits, symmetric=False), - weights=QuantizationArgs(num_bits=num_bits, symmetric=True), - output_activations=None, - targets = ["*"], -) - -layer = Linear(4, 4) -print(layer) -print(dict(layer.named_parameters())) - - -initialize_module_for_quantization(layer, scheme) -print(layer) # should see observer under layer now -print(0) -print(dict(layer.named_parameters())) # should see empty tensors for scale and zero point now -print(1) - - -set_module_for_calibration(layer) -# do a calibration step -layer(torch.randn(4,4)) -print(dict(layer.named_parameters())) # scale and zero point should have updated values -print(2) -print("calib layers ") -for i in range(10): - print("iter", i) - layer(torch.randn(4,4)) -print(dict(layer.named_parameters())) # scale and zero point should have updated values again since we did another pass - -print(3) -# breakpoint() - - -freeze_module_quantization(layer) -print("freeze layers ") -for i in range(10): - # do more forward passes but show args are frozen - print("iter", i) - layer(torch.randn(4,4)) -print(dict(layer.named_parameters())) # scale and zero point should not be updated now - - -# # missing - -# # correctness -# # quantizing an entire model - - - diff --git a/src/sparsetensors/quantization/lifecycle/forward.py b/src/sparsetensors/quantization/lifecycle/forward.py index cbb27dea..ab20e29b 100644 --- a/src/sparsetensors/quantization/lifecycle/forward.py +++ b/src/sparsetensors/quantization/lifecycle/forward.py @@ -16,11 +16,6 @@ import torch from sparsetensors.quantization.lifecycle.status import QuantizationStatus - -# from sparsetensors.quantization.utils.quantization_scheme import ( -# QuantizationArgs, 
-# QuantizationScheme, -# ) from sparsetensors.quantization.quant_args import QuantizationArgs from sparsetensors.quantization.quant_scheme import QuantizationScheme from torch.nn import Module @@ -59,35 +54,7 @@ def fake_quantize( args: QuantizationArgs, ) -> torch.Tensor: max_q = torch.tensor(2**args.num_bits - 1) - columns = x.shape[1] Q = torch.zeros_like(x) - # for i1 in range(0, columns, args.block_size): - # i2 = min(i1 + args.block_size, columns) - # count = i2 - i1 - - # W1 = x[:, i1:i2].clone() - # Q1 = torch.zeros_like(W1) - - # for i in range(count): - # w = W1[:, i] - # breakpoint() - # if args.group_size != -1: - # if (i1 + i) % args.group_size == 0: - # xmin, xmax = get_qparams( - # x[:, (i1 + i) : (i1 + i + args.group_size)], args.symmetric - # ) - # scale, zero = get_scale_zero_point( - # x[:, (i1 + i) : (i1 + i + args.group_size)], - # max_q, - # xmax, - # xmin, - # args.symmetric, - # args.group_size, - # ) - - # q = quantize(w.unsqueeze(1), scale, zero, max_q).flatten() - # Q1[:, i] = q - # Q[:, i1:i2] = Q1 Q = quantize(x, scale, zero_point, max_q) return dequantize(Q, scale, zero_point) diff --git a/src/sparsetensors/quantization/lifecycle/initialize.py b/src/sparsetensors/quantization/lifecycle/initialize.py index 6d23f4cc..5661fdbd 100644 --- a/src/sparsetensors/quantization/lifecycle/initialize.py +++ b/src/sparsetensors/quantization/lifecycle/initialize.py @@ -18,11 +18,6 @@ import torch from sparsetensors.quantization.lifecycle.forward import wrap_module_forward_quantized from sparsetensors.quantization.lifecycle.status import QuantizationStatus - -# from sparsetensors.quantization.utils.quantization_scheme import ( -# QuantizationArgs, -# QuantizationScheme, -# ) from sparsetensors.quantization.quant_args import QuantizationArgs from sparsetensors.quantization.quant_scheme import QuantizationScheme from torch.nn import Module, Parameter diff --git a/src/sparsetensors/quantization/observers/base.py b/src/sparsetensors/quantization/observers/base.py index 00cd7561..a3184096 100644 --- a/src/sparsetensors/quantization/observers/base.py +++ b/src/sparsetensors/quantization/observers/base.py @@ -14,7 +14,6 @@ from typing import Optional, Tuple -# from sparsetensors.quantization.utils.quantization_scheme import QuantizationArgs from sparsetensors.quantization.quant_args import QuantizationArgs from sparsezoo.utils.registry import RegistryMixin from torch import FloatTensor, IntTensor, Tensor diff --git a/src/sparsetensors/quantization/observers/memoryless.py b/src/sparsetensors/quantization/observers/memoryless.py index faabbb5a..b69c841d 100644 --- a/src/sparsetensors/quantization/observers/memoryless.py +++ b/src/sparsetensors/quantization/observers/memoryless.py @@ -19,9 +19,6 @@ from torch import FloatTensor, IntTensor, Tensor -# from sparsetensors.quantization.utils.quantization_scheme import QuantizationArgs - - __all__ = ["MemorylessObserver"] diff --git a/src/sparsetensors/quantization/quant_args.py b/src/sparsetensors/quantization/quant_args.py index fb9e9b01..d90fe9bc 100644 --- a/src/sparsetensors/quantization/quant_args.py +++ b/src/sparsetensors/quantization/quant_args.py @@ -64,7 +64,8 @@ class QuantizationArgs(BaseModel): observer: str = Field( default="minmax", description=( - "The class to use to compute the quantization params - scale and zero-point'" + "The class to use to compute the quantization param - " + "scale and zero-point'" ), ) observer_kwargs: Dict[str, Any] = Field( From b195681148abc8d488cdbbbd2ccb417a49dd545b Mon Sep 17 00:00:00 2001 
From: George Ohashi Date: Fri, 12 Apr 2024 18:13:23 +0000 Subject: [PATCH 5/6] specify sparsezoo version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index de506c99..c2630095 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ def _setup_install_requires() -> List: return ["torch>=1.7.0", "transformers<=4.40", "pydantic<2.7"] def _setup_extras() -> Dict: - return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0", "sparsezoo"]} + return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0", "sparsezoo==1.7.0"]} setup( name="sparsetensors", From 045841772732bc784c03be036059cec5cf28a3cf Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Fri, 12 Apr 2024 18:26:34 +0000 Subject: [PATCH 6/6] remove sparsezoo --- setup.py | 2 +- src/sparsetensors/quantization/observers/base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index c2630095..89a17ad7 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ def _setup_install_requires() -> List: return ["torch>=1.7.0", "transformers<=4.40", "pydantic<2.7"] def _setup_extras() -> Dict: - return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0", "sparsezoo==1.7.0"]} + return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0",]} setup( name="sparsetensors", diff --git a/src/sparsetensors/quantization/observers/base.py b/src/sparsetensors/quantization/observers/base.py index a3184096..52a464b9 100644 --- a/src/sparsetensors/quantization/observers/base.py +++ b/src/sparsetensors/quantization/observers/base.py @@ -15,7 +15,7 @@ from typing import Optional, Tuple from sparsetensors.quantization.quant_args import QuantizationArgs -from sparsezoo.utils.registry import RegistryMixin +from sparsetensors.registry.registry import RegistryMixin from torch import FloatTensor, IntTensor, Tensor from torch.nn import Module
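
For anyone trying the series end to end, the flow exercised by the (now deleted) bin/quant.py script condenses to the sketch below. It is a minimal sketch, assuming the sparsetensors package from this branch is installed and that the new QuantizationScheme in quant_scheme.py accepts the input_activations, weights, output_activations, and targets keywords the script passes (the script's input_acivations spelling reads like a typo for input_activations).

import torch
from torch.nn import Linear

from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from sparsetensors.quantization.lifecycle.initialize import initialize_module_for_quantization
from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization

# 8-bit scheme: asymmetric input activations, symmetric weights, no output quantization
scheme = QuantizationScheme(
    input_activations=QuantizationArgs(num_bits=8, symmetric=False),
    weights=QuantizationArgs(num_bits=8, symmetric=True),
    output_activations=None,
    targets=["*"],
)

layer = Linear(4, 4)

# attach observers and empty scale / zero-point parameters to the module
initialize_module_for_quantization(layer, scheme)

# calibration: every forward pass lets the observers update scale and zero point
set_module_for_calibration(layer)
for _ in range(10):
    layer(torch.randn(4, 4))

# frozen: further forward passes fake-quantize with fixed qparams
freeze_module_quantization(layer)
layer(torch.randn(4, 4))
print(dict(layer.named_parameters()))  # scale and zero_point should no longer change

During calibration each call should move scale and zero_point; after freeze_module_quantization they should stay put, which is what the print checks in the original script were eyeballing.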
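
Once an observer has produced scale and zero_point, fake_quantize in lifecycle/forward.py reduces to a standard asymmetric quantize/dequantize round trip over max_q = 2**num_bits - 1, and min_max.py pins the convention with zero_point = (0 - min_val) / scale. The snippet below is a dependency-free illustration of that arithmetic, not the actual quantize and dequantize helpers from forward.py.

import torch

def fake_quantize_sketch(x, scale, zero_point, num_bits=8):
    # same integer range as forward.py: 0 .. 2**num_bits - 1
    max_q = 2**num_bits - 1
    q = torch.clamp(torch.round(x / scale + zero_point), 0, max_q)  # quantize
    return (q - zero_point) * scale  # dequantize

# min-max style qparams computed from a single observed tensor
x = torch.randn(4, 4)
min_val, max_val = x.min(), x.max()
scale = (max_val - min_val) / (2**8 - 1)
zero_point = (0 - min_val) / scale  # matches the zero-point line in min_max.py
error = (x - fake_quantize_sketch(x, scale, zero_point)).abs().max()
print(error)  # round-trip error stays within roughly scale / 2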