From 6e0c4fcaeb16d14ff3b7926efe611549aba053dd Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 2 Sep 2024 00:10:53 +0100 Subject: [PATCH 1/3] HQO --- src/brevitas/core/stats/stats_op.py | 222 ++++++++++++++++++ src/brevitas/core/zero_point.py | 2 +- src/brevitas/quant/base.py | 43 +++- src/brevitas/quant/scaled_int.py | 25 ++ src/brevitas/quant/shifted_scaled_int.py | 69 +++++- .../common/generative/quant_blocks.py | 1 - .../common/generative/quantize.py | 20 +- .../common/generative/quantizers.py | 12 +- .../imagenet_classification/ptq/ptq_common.py | 23 +- .../ptq/ptq_evaluate.py | 2 +- src/brevitas_examples/llm/main.py | 2 +- tests/brevitas_examples/test_llm.py | 4 + 12 files changed, 399 insertions(+), 26 deletions(-) diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py index fac729326..461aeb3e6 100644 --- a/src/brevitas/core/stats/stats_op.py +++ b/src/brevitas/core/stats/stats_op.py @@ -10,6 +10,8 @@ import brevitas from brevitas import config +from brevitas.core.function_wrapper.misc import Identity +from brevitas.core.function_wrapper.ops_ste import ScalarClampMinSte from brevitas.core.utils import StatelessBuffer from brevitas.function.ops import max_int from brevitas.quant_tensor import _unpack_quant_tensor @@ -544,3 +546,223 @@ def forward(self, x): x = self.input_view_shape_impl(x) self.internal_candidate = self.mse_init_op(x).detach() return self.internal_candidate + + +class HalfQuadraticOptimizerScale(torch.nn.Module): + # References: + # https://mobiusml.github.io/hqq_blog/ + # https://github.com/mobiusml/hqq?tab=readme-ov-file + + def __init__( + self, + proxy_module, + hqo_init_op_scale, + keepdim: bool, + inner_stats_input_view_shape_impl: torch.nn.Module, + scaling_min_val: Optional[float] = None, + stats_reduce_dim: Optional[int] = None, + int_scaling_impl=None, + bit_width_impl=None, + hqo_beta_scale: float = 1e5, + hqo_kappa_scale: float = 1.01, + hqo_lp_norm_scale: float = .7, + hqo_iters_scale: int = 1000): + super(HalfQuadraticOptimizerScale, self).__init__() + self.hqo_init_op = hqo_init_op_scale + self.input_view_shape_impl = inner_stats_input_view_shape_impl + self.proxy_forward = proxy_module.forward + self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled) + self.internal_candidate = None + self.hqo_iters = hqo_iters_scale + self.stats_reduce_dim = stats_reduce_dim + self.local_loss_mode: bool = False + + self.beta = hqo_beta_scale + self.kappa = hqo_kappa_scale + self.lp_norm = hqo_lp_norm_scale + + self.int_scaling_impl = int_scaling_impl + self.msb_clamp_bit_width_impl = bit_width_impl + if scaling_min_val is not None and scaling_min_val != 0: + self.clamp_min_ste = ScalarClampMinSte(scaling_min_val) + else: + self.clamp_min_ste = Identity() + self.keepdim = keepdim + + def parameter_search(self, xl, x): + best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype) + candidate = xl + best_candidate = candidate + beta = self.beta + with torch.no_grad(): + for i in range(0, self.hqo_iters): + self.internal_candidate = candidate + self.set_local_loss_mode(True) + quant_tensor = self.proxy_forward(x).detach() + self.set_local_loss_mode(False) + loss = torch.abs(quant_tensor.value - x).mean() + + best_candidate = torch.where(loss < best_loss, candidate, best_candidate) + if loss >= best_loss: + break + best_loss = torch.min(loss, best_loss) + W_e = shrink_lp_op(x - quant_tensor.value, beta, self.lp_norm) + zero_point = quant_tensor.zero_point + num = self.input_view_shape_impl(x - W_e).detach() + den = self.input_view_shape_impl( + torch.round(quant_tensor.value / quant_tensor.scale) - zero_point).detach() + mask = (num != 0.) & (den != 0.) + if self.stats_reduce_dim is None: + candidate = masked_median(num / den, mask) + else: + candidate = masked_median( + num / den, mask, dim=self.stats_reduce_dim, keepdim=self.keepdim) + candidate = candidate.type_as(self.internal_candidate) + candidate = self.clamp_min_ste(candidate) + bit_width = self.msb_clamp_bit_width_impl() + int_threshold = self.int_scaling_impl(bit_width) + candidate = candidate * int_threshold + candidate[torch.isnan(candidate)] = self.internal_candidate[torch.isnan(candidate)] + candidate[torch.isinf(candidate)] = self.internal_candidate[torch.isinf(candidate)] + beta *= self.kappa + return best_candidate + + def optimize(self, x): + x_view = self.input_view_shape_impl(x) + + init = self.hqo_init_op(x_view).detach() + best_candidate = self.parameter_search(init, x_view) + + # Save for evaluation by other modules (e.g. zp) invoking local loss mode + self.internal_candidate = best_candidate.detach() + torch.cuda.empty_cache() + return best_candidate + + def forward(self, x): + if not self.local_loss_mode: + with torch.no_grad(): + return self.optimize(x) + else: + # This is invoked for the zero-point whenever scale is being optimized first + if self.internal_candidate is None: + x = self.input_view_shape_impl(x) + self.internal_candidate = self.hqo_init_op(x).detach() + return self.internal_candidate + + +class HalfQuadraticOptimizerZeroPoint(torch.nn.Module): + # References: + # https://mobiusml.github.io/hqq_blog/ + # https://github.com/mobiusml/hqq?tab=readme-ov-file + + def __init__( + self, + proxy_module, + keepdim: bool, + hqo_init_op_zp: torch.nn.Module, + inner_stats_input_view_shape_impl: torch.nn.Module, + stats_reduce_dim: Optional[int] = None, + hqo_beta_zp: float = 1e0, + hqo_kappa_zp: float = 1.01, + hqo_lp_norm_zp: float = .5, + hqo_iters_zp: int = 1000): + super(HalfQuadraticOptimizerZeroPoint, self).__init__() + self.hqo_init_op_zp = hqo_init_op_zp + self.input_view_shape_impl = inner_stats_input_view_shape_impl + self.proxy_forward = proxy_module.forward + self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled) + self.internal_candidate = None + self.stats_reduce_dim = stats_reduce_dim + self.local_loss_mode: bool = False + self.beta = hqo_beta_zp + self.kappa = hqo_kappa_zp + self.lp_norm = hqo_lp_norm_zp + self.hqo_iters = hqo_iters_zp + self.keepdim = keepdim + + def parameter_search(self, xl, x): + best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype) + candidate = xl + best_candidate = candidate + with torch.no_grad(): + for i in range(0, self.hqo_iters): + self.internal_candidate = candidate + self.set_local_loss_mode(True) + quant_tensor = self.proxy_forward(x).detach() + self.set_local_loss_mode(False) + qt_value = self.input_view_shape_impl(quant_tensor.value) + qt_scale = self.input_view_shape_impl(quant_tensor.scale) + qt_int = self.input_view_shape_impl(quant_tensor.int()) + loss = torch.abs(qt_value - x).mean() + best_candidate = torch.where(loss < best_loss, candidate, best_candidate) + if loss >= best_loss: + break + best_loss = torch.min(loss, best_loss) + W_e = shrink_lp_op(x - qt_value, self.beta, self.lp_norm) + + val = self.input_view_shape_impl((x - W_e) - qt_int * qt_scale) + + if self.stats_reduce_dim is None: + candidate = torch.mean(val) + else: + candidate = torch.mean(val, dim=self.stats_reduce_dim, keepdim=self.keepdim) + self.beta *= self.kappa + return best_candidate + + def optimize(self, x): + x_view = self.input_view_shape_impl(x) + + init = self.hqo_init_op_zp(x_view).detach() + + best_candidate = self.parameter_search(init, x) + + # Save for evaluation by other modules (e.g. zp) invoking local loss mode + self.internal_candidate = best_candidate.detach() + torch.cuda.empty_cache() + return best_candidate + + def forward(self, x): + if not self.local_loss_mode: + with torch.no_grad(): + return self.optimize(x) + else: + # This is invoked for the zero-point whenever scale is being optimized first + if self.internal_candidate is None: + x = self.input_view_shape_impl(x) + self.internal_candidate = self.hqo_init_op_zp(x).detach() + return self.internal_candidate + + +def masked_median(x, mask, dim=None, keepdim=False): + """Compute the median of tensor x along dim, ignoring values where mask is False. + x and mask need to be broadcastable. + + Args: + x (Tensor): Tensor to compute median of. + mask (BoolTensor): Same shape as x with True where x is valid and False + where x should be masked. Mask should not be all False in any column of + dimension dim to avoid NaNs from zero division. + dim (int, optional): Dimension to take median of. Defaults to 0. + + Returns: + Tensor: Same shape as x, except dimension dim reduced. + """ + # uncomment this assert for safety but might impact performance + # assert ( + # mask.sum(dim=dim).ne(0).all() + # ), "mask should not be all False in any column, causes zero division" + x_nan = x.float().masked_fill(~mask, float("nan")) + if dim is None: + x_median = x_nan.nanmedian() + else: + x_median, _ = x_nan.nanmedian(dim=dim, keepdim=keepdim) + return x_median + + +# Shrinking operator +def shrink_lp_op(x: Tensor, beta: float, lp_norm: float) -> Tensor: + if lp_norm == 1: + return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta) + else: + return torch.sign(x) * torch.nn.functional.relu( + torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x), lp_norm - 1)) diff --git a/src/brevitas/core/zero_point.py b/src/brevitas/core/zero_point.py index 872435ec7..3f80f1dd4 100644 --- a/src/brevitas/core/zero_point.py +++ b/src/brevitas/core/zero_point.py @@ -281,7 +281,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): output_dict = super(ParameterFromStatsFromParameterZeroPoint, self).state_dict( destination=destination, prefix=prefix, keep_vars=keep_vars) # Avoid saving the init value - if not self.init_done: + if not self.init_done and not config._FULL_STATE_DICT: del output_dict[prefix + 'value'] return output_dict diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index 7b6fe409e..92f41b990 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -37,6 +37,8 @@ from brevitas.core.stats import MSE from brevitas.core.stats import NegativeMinOrZero from brevitas.core.stats import NegativePercentileOrZero +from brevitas.core.stats.stats_op import HalfQuadraticOptimizerScale +from brevitas.core.stats.stats_op import HalfQuadraticOptimizerZeroPoint from brevitas.core.utils import SingleArgStatelessBuffer from brevitas.core.zero_point import ParameterFromRuntimeZeroPoint from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint @@ -458,7 +460,7 @@ class MSEAsymmetricScaleSubInjector(MSESubInjectorBase): stats_impl = MSE stats_reduce_dim = (this << 1).stats_reduce_dim device = (this << 1).device - type = (this << 1).type + dtype = (this << 1).dtype class MSEZeroPointSubInjector(MSESubInjectorBase): @@ -470,7 +472,7 @@ class MSEZeroPointSubInjector(MSESubInjectorBase): stats_impl = MSE stats_reduce_dim = (this << 1).stats_reduce_dim device = (this << 1).device - type = (this << 1).type + dtype = (this << 1).dtype class MSEAsymmetricScale(ExtendedInjector): @@ -520,3 +522,40 @@ class MSEWeightZeroPoint(MSEZeroPoint): class MSEActZeroPoint(MSEZeroPoint): zero_point_impl = ParameterFromRuntimeZeroPoint + + +class HQOZeroPoint(ExtendedInjector): + + hqo_init_op_zp = NegativeMinOrZero + inner_stats_input_view_shape_impl = this.zero_point_stats_input_view_shape_impl + stats_impl_zp = HalfQuadraticOptimizerZeroPoint + + @value + def zero_point_stats_impl(): + return this.stats_impl_zp + + +class HQOScale(ExtendedInjector): + scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS + inner_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl + stats_impl_scale = HalfQuadraticOptimizerScale + + @value + def scaling_stats_impl(): + return this.stats_impl_scale + + +class HQOAsymmetricScale(HQOScale): + hqo_init_op_scale = AbsMinMax + + +class HQOSymmetricScale(HQOScale): + hqo_init_op_scale = AbsMax + + +class HQOActZeroPoint(HQOZeroPoint): + zero_point_impl = ParameterFromRuntimeZeroPoint + + +class HQOWeightZeroPoint(HQOZeroPoint): + zero_point_impl = ParameterFromStatsFromParameterZeroPoint diff --git a/src/brevitas/quant/scaled_int.py b/src/brevitas/quant/scaled_int.py index 0f67300c3..b5f9174d7 100644 --- a/src/brevitas/quant/scaled_int.py +++ b/src/brevitas/quant/scaled_int.py @@ -3,6 +3,7 @@ from brevitas.core.function_wrapper import TensorClamp from brevitas.quant.base import * +from brevitas.quant.base import HQOSymmetricScale from brevitas.quant.solver.act import ActQuantSolver from brevitas.quant.solver.bias import BiasQuantSolver from brevitas.quant.solver.trunc import TruncQuantSolver @@ -443,3 +444,27 @@ class Int8AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareZeroCenterWeight >>> conv.quant_weight() """ bit_width = 8 + + +class Int8WeightPerTensorFloatHQO(HQOSymmetricScale, Int8WeightPerTensorFloat): + """ + 8-bit narrow per-tensor signed int weight quantizer with per-tensor floating-point scale factor computed + from HQO local loss. + + Examples: + >>> from brevitas.nn import QuantLinear + >>> fc = QuantLinear(10, 5, bias=False, weight_quant=Int8WeightPerTensorFloatHQO) + """ + pass + + +class Int8WeightPerChannelFloatHQO(HQOSymmetricScale, Int8WeightPerChannelFloat): + """ + 8-bit narrow per-tensor signed int weight quantizer with per-tensor floating-point scale factor computed + from HQO local loss. + + Examples: + >>> from brevitas.nn import QuantLinear + >>> fc = QuantLinear(10, 5, bias=False, weight_quant=Int8WeightPerChannelFloatHQO) + """ + pass diff --git a/src/brevitas/quant/shifted_scaled_int.py b/src/brevitas/quant/shifted_scaled_int.py index 936737571..d18150a10 100644 --- a/src/brevitas/quant/shifted_scaled_int.py +++ b/src/brevitas/quant/shifted_scaled_int.py @@ -1,10 +1,12 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from brevitas.inject.enum import ScalingPerOutputType +from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector from brevitas.quant.base import * +from brevitas.quant.base import HQOActZeroPoint +from brevitas.quant.base import HQOZeroPoint from brevitas.quant.solver.act import ActQuantSolver -from brevitas.quant.solver.bias import BiasQuantSolver -from brevitas.quant.solver.trunc import TruncQuantSolver from brevitas.quant.solver.weight import WeightQuantSolver __all__ = [ @@ -15,7 +17,10 @@ 'ShiftedUint8ActPerTensorFixedPointMSE', 'ShiftedUint8ActPerTensorFloatMSE', 'ShiftedUint8WeightPerTensorFloatMSE', - 'ShiftedUint8WeightPerChannelFloatMSE'] + 'ShiftedUint8WeightPerChannelFloatMSE', + 'ShiftedUint8ActPerTensorFloatHQO', + 'ShiftedUint8WeightPerChannelFloatHQO', + 'ShiftedUint8WeightPerTensorFloatHQO'] class ShiftedUint8ActPerTensorFixedPoint(ShiftedParamFromPercentileUintQuant, @@ -138,3 +143,61 @@ class ShiftedUint8WeightPerChannelFloatMSE(MSEAsymmetricScale, >>> fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerChannelFloat) """ pass + + +class ShiftedUint8WeightPerTensorFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerTensorFloat): + """ + 8-bit per-tensor unsigned int weight quantizer with floating-point per-channel scale factor and integer + zero point. Zero-point is initialized from HQO local loss. + + Examples: + >>> from brevitas.nn import QuantLinear + >>> fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerTensorFloatHQO) + """ + quantize_zero_point = False + + +class ShiftedUint8WeightPerChannelFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerChannelFloat): + """ + 8-bit per-tensor unsigned int weight quantizer with floating-point per-channel scale factor and integer + zero point. Zero-point is initialized from HQO local loss. + + Examples: + >>> from brevitas.nn import QuantLinear + >>> fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerChannelFloatHQO) + """ + quantize_zero_point = False + + +class ShiftedUint8WeightPerGroupFloatHQO(ShiftedUint8WeightPerChannelFloatHQO): + """ + 8-bit per-tensor unsigned int weight quantizer with floating-point per-channel scale factor and integer + zero point.Zero-point is initialized from HQO local loss. + Examples: + >>> from brevitas.nn import QuantLinear + >>> fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerChannelFloatHQO) + """ + group_size = 32 + scaling_per_output_type = ScalingPerOutputType.GROUP + proxy_class = GroupwiseWeightQuantProxyFromInjector + + +class ShiftedUint8ActPerTensorFloatHQO(HQOActZeroPoint, ShiftedUint8ActPerTensorFloat): + """ + 8-bit per-tensor unsigned int activations quantizer with floating-point scale factor and + integer zero point. Zero-point is learned parameter initialized from + HQO local loss. + + Examples: + >>> from brevitas.nn import QuantReLU + >>> act = QuantReLU(act_quant=ShiftedUint8ActPerTensorFloatHQO) + """ + quantize_zero_point = False + + +class ShiftedUint8WeightGroupQuantFloat(ShiftedUint8WeightPerChannelFloat): + """ + Block / group / vector signed asymmetric weight quantizer with float scales and zero-points. + """ + proxy_class = GroupwiseWeightQuantProxyFromInjector + scaling_per_output_type = ScalingPerOutputType.GROUP diff --git a/src/brevitas_examples/common/generative/quant_blocks.py b/src/brevitas_examples/common/generative/quant_blocks.py index 696340a2c..93cc235e2 100644 --- a/src/brevitas_examples/common/generative/quant_blocks.py +++ b/src/brevitas_examples/common/generative/quant_blocks.py @@ -5,7 +5,6 @@ from typing import Callable -import torch from torch import Tensor import torch.nn as nn diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 73831d5fa..10f7ce259 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -37,14 +37,20 @@ from brevitas.quant.scaled_int import Int8ActPerTensorFloat from brevitas.quant.scaled_int import Int8ActPerTensorFloatMSE from brevitas.quant.scaled_int import Int8WeightPerChannelFloat +from brevitas.quant.scaled_int import Int8WeightPerChannelFloatHQO from brevitas.quant.scaled_int import Int8WeightPerChannelFloatMSE from brevitas.quant.scaled_int import Int8WeightPerTensorFloat +from brevitas.quant.scaled_int import Int8WeightPerTensorFloatHQO from brevitas.quant.scaled_int import Int8WeightPerTensorFloatMSE from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatMSE +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightGroupQuantFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerGroupFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear @@ -56,7 +62,6 @@ from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerRowFloat from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat -from brevitas_examples.common.generative.quantizers import ShiftedUintWeightAsymmetricGroupQuant WEIGHT_QUANT_MAP = { 'int': { @@ -68,14 +73,23 @@ 'sym': Int8WeightPerChannelFloat, 'asym': ShiftedUint8WeightPerChannelFloat}, 'per_group': { 'sym': IntWeightSymmetricGroupQuant, - 'asym': ShiftedUintWeightAsymmetricGroupQuant}}, + 'asym': ShiftedUint8WeightGroupQuantFloat}}, 'mse': { 'per_tensor': { 'sym': Int8WeightPerTensorFloatMSE, 'asym': ShiftedUint8WeightPerTensorFloatMSE}, 'per_channel': { 'sym': Int8WeightPerChannelFloatMSE, - 'asym': ShiftedUint8WeightPerChannelFloatMSE},},}, + 'asym': ShiftedUint8WeightPerChannelFloatMSE}}, + 'hqo': { + 'per_tensor': { + 'sym': Int8WeightPerTensorFloatHQO, + 'asym': ShiftedUint8WeightPerTensorFloatHQO}, + 'per_channel': { + 'sym': Int8WeightPerChannelFloatHQO, + 'asym': ShiftedUint8WeightPerChannelFloatHQO}, + 'per_group': { + 'asym': ShiftedUint8WeightPerGroupFloatHQO}},}, 'po2_scale': { 'stats': { 'per_tensor': { diff --git a/src/brevitas_examples/common/generative/quantizers.py b/src/brevitas_examples/common/generative/quantizers.py index 1f41e136a..c3c99a96f 100644 --- a/src/brevitas_examples/common/generative/quantizers.py +++ b/src/brevitas_examples/common/generative/quantizers.py @@ -10,7 +10,9 @@ from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling from brevitas.core.stats import AbsMinMax from brevitas.core.stats import NegativeMinOrZero +from brevitas.core.stats.stats_op import HalfQuadraticOptimizerZeroPoint from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE +from brevitas.core.zero_point import StatsFromParameterZeroPoint from brevitas.inject import ExtendedInjector from brevitas.inject import this from brevitas.inject import value @@ -21,11 +23,13 @@ from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector from brevitas.proxy.groupwise_int_runtime_quant import GroupwiseActQuantProxyFromInjector from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector +from brevitas.quant.base import HQOWeightZeroPoint from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat from brevitas.quant.scaled_int import Int8ActPerTensorFloat from brevitas.quant.scaled_int import Int8WeightPerChannelFloat +from brevitas.quant.scaled_int import Int8WeightPerChannelFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat @@ -54,14 +58,6 @@ class Fp8e4m3WeightSymmetricGroupQuant(Fp8e4m3WeightPerChannelFloat): scaling_per_output_type = ScalingPerOutputType.GROUP -class ShiftedUintWeightAsymmetricGroupQuant(ShiftedUint8WeightPerChannelFloat): - """ - Block / group / vector signed asymmetric weight quantizer with float scales and zero-points. - """ - proxy_class = GroupwiseWeightQuantProxyFromInjector - scaling_per_output_type = ScalingPerOutputType.GROUP - - class Int8DynamicActPerTensorFloat(DynamicActProxyMixin, Int8ActPerTensorFloat): """ Symmetric quantizer with per tensor dynamic scale. diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index bac596be5..83ab25b69 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -22,7 +22,6 @@ from brevitas.graph.target.flexml import quantize_flexml from brevitas.inject import value import brevitas.nn as qnn -from brevitas.quant.experimental.float import Fp8e4m3Act from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloatMSE from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat @@ -52,17 +51,22 @@ from brevitas.quant.scaled_int import Int8ActPerTensorFloat from brevitas.quant.scaled_int import Int8ActPerTensorFloatMSE from brevitas.quant.scaled_int import Int8WeightPerChannelFloat +from brevitas.quant.scaled_int import Int8WeightPerChannelFloatHQO from brevitas.quant.scaled_int import Int8WeightPerChannelFloatMSE from brevitas.quant.scaled_int import Int8WeightPerTensorFloat +from brevitas.quant.scaled_int import Int8WeightPerTensorFloatHQO from brevitas.quant.scaled_int import Int8WeightPerTensorFloatMSE from brevitas.quant.scaled_int import Int16Bias from brevitas.quant.scaled_int import Int32Bias from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFixedPoint from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatMSE from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatHQO from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat @@ -104,7 +108,14 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat): 'asym': ShiftedUint8WeightPerTensorFloatMSE}, 'per_channel': { 'sym': Int8WeightPerChannelFloatMSE, - 'asym': ShiftedUint8WeightPerChannelFloatMSE},},}, + 'asym': ShiftedUint8WeightPerChannelFloatMSE}}, + 'hqo': { + 'per_tensor': { + 'sym': Int8WeightPerTensorFloatHQO, + 'asym': ShiftedUint8WeightPerTensorFloatHQO}, + 'per_channel': { + 'sym': Int8WeightPerChannelFloatHQO, + 'asym': ShiftedUint8WeightPerChannelFloatHQO}}}, 'po2_scale': { 'stats': { 'per_tensor': { @@ -411,7 +422,9 @@ def kwargs_prefix(prefix, weight_kwargs): act_scale_type][act_param_method][act_quant_granularity]['sym'] act_quant = act_quant.let(**act_bit_width_dict) + act_quant = act_quant.let(**{'dtype': dtype, 'device': device}) sym_act_quant = sym_act_quant.let(**act_bit_width_dict) + sym_act_quant = sym_act_quant.let(**{'dtype': dtype, 'device': device}) else: act_quant = None sym_act_quant = None @@ -424,15 +437,13 @@ def kwargs_prefix(prefix, weight_kwargs): if weight_quant_type == 'asym': weight_quant = weight_quant.let(zero_point_impl=ParameterFromStatsFromParameterZeroPoint) if act_quant is not None: - act_quant = act_quant.let( - **{ - 'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device}) + act_quant = act_quant.let(**{'high_percentile_q': act_quant_percentile}) if act_quant_type == 'asym' and act_quant_percentile is not None: act_quant = act_quant.let(**{'low_percentile_q': 100 - act_quant_percentile}) if sym_act_quant is not None: sym_act_quant = sym_act_quant.let( **{ - 'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device}) + 'high_percentile_q': act_quant_percentile}) weight_quant_dict = {'weight_quant': weight_quant} diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index c960a89e6..58cc6563f 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -130,7 +130,7 @@ def parse_type(v, default_type): parser.add_argument( '--weight-quant-calibration-type', default='stats', - choices=['stats', 'mse'], + choices=['stats', 'mse', 'hqo'], help='Weight quantization calibration type (default: stats)') parser.add_argument( '--act-equalization', diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 6060ef498..e19390774 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -331,7 +331,7 @@ def parse_args(args): '--weight-param-method', type=str, default='stats', - choices=['stats', 'mse'], + choices=['stats', 'mse', 'hqo'], help='How scales/zero-point are determined. Default: stats.') parser.add_argument( '--weight-scale-precision', diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 6a98911a9..ee5db176f 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -139,6 +139,8 @@ def run_test_models_run_args(args, model_with_ppl): @pytest_cases.fixture( ids=[ "defaults", + "sym_weight_param_method=hqo", + "asym_weight_param_method=hqo", "bias_corr=True", "act_equalization=layerwise", "act_equalization=fx", @@ -147,6 +149,8 @@ def run_test_models_run_args(args, model_with_ppl): "ln_affine_merge=True",], params=[ {}, + {"weight_param_method": "hqo"}, + {"weight_param_method": "hqo", "weight_quant_type": "asym"}, {"bias_corr": True}, {"act_equalization": "layerwise"}, {"act_equalization": "fx"}, From 16d75d8713b2d797048243b3fda8122b6daba3df Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 11 Sep 2024 17:31:00 +0100 Subject: [PATCH 2/3] fix tests --- src/brevitas/quant_tensor/int_quant_tensor.py | 5 +++-- .../imagenet_classification/ptq/ptq_common.py | 4 +--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/brevitas/quant_tensor/int_quant_tensor.py b/src/brevitas/quant_tensor/int_quant_tensor.py index 3572a8900..f06e321b2 100644 --- a/src/brevitas/quant_tensor/int_quant_tensor.py +++ b/src/brevitas/quant_tensor/int_quant_tensor.py @@ -15,7 +15,7 @@ from .torch_handler import QUANT_TENSOR_FN_HANDLER IS_VALID_ATOL = 2e-1 -BFLOAT16_IS_VALID_ATOL = 0.5 +B_FLOAT16_IS_VALID_ATOL = 0.5 class IntQuantTensor(IntQuantTensorBase, QuantTensor): @@ -78,7 +78,8 @@ def is_valid(self): pre_round_int_value = self._pre_round_int_value rounded_int_value = torch.round(pre_round_int_value) max_abs_diff = torch.max(torch.abs(pre_round_int_value - rounded_int_value)) - atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL + atol = B_FLOAT16_IS_VALID_ATOL if self.value.dtype in ( + torch.bfloat16, torch.float16) else IS_VALID_ATOL is_int = max_abs_diff < atol if self.bit_width >= 2: if self.signed: diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 83ab25b69..7d846ce8d 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -441,9 +441,7 @@ def kwargs_prefix(prefix, weight_kwargs): if act_quant_type == 'asym' and act_quant_percentile is not None: act_quant = act_quant.let(**{'low_percentile_q': 100 - act_quant_percentile}) if sym_act_quant is not None: - sym_act_quant = sym_act_quant.let( - **{ - 'high_percentile_q': act_quant_percentile}) + sym_act_quant = sym_act_quant.let(**{'high_percentile_q': act_quant_percentile}) weight_quant_dict = {'weight_quant': weight_quant} From 04a7de8bec7e4c4d2dd30c4aa03d64e27aa177bf Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 11 Sep 2024 19:03:35 +0100 Subject: [PATCH 3/3] Fix local loss tests + JIT --- noxfile.py | 2 +- tests/brevitas_examples/test_llm.py | 2 ++ tests/brevitas_examples/test_quantize_model.py | 3 +++ tests/marker.py | 10 ++++++++++ 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/noxfile.py b/noxfile.py index 17a38789d..8aad90528 100644 --- a/noxfile.py +++ b/noxfile.py @@ -120,7 +120,7 @@ def tests_brevitas_cpu(session, pytorch, jit_status): @nox.parametrize("pytorch", PYTORCH_VERSIONS, ids=PYTORCH_IDS) @nox.parametrize("jit_status", JIT_STATUSES, ids=JIT_IDS) def tests_brevitas_examples_cpu(session, pytorch, jit_status): - session.env['PYTORCH_JIT'] = '{}'.format(int(jit_status == 'jit_enabled')) + session.env['BREVITAS_JIT'] = '{}'.format(int(jit_status == 'jit_enabled')) install_pytorch(pytorch, session) install_torchvision(pytorch, session) # For CV eval scripts session.install('--upgrade', '.[test, tts, stt, vision]') diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index ee5db176f..90981df29 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -161,6 +161,8 @@ def run_test_models_run_args(args, model_with_ppl): def toggle_run_args(default_run_args, request): args = default_run_args args.update(**request.param) + if args.weight_param_method == 'hqo' and config.JIT_ENABLED: + pytest.skip("Local loss mode requires JIT to be disabled") yield args diff --git a/tests/brevitas_examples/test_quantize_model.py b/tests/brevitas_examples/test_quantize_model.py index 6a7184131..8ab34d5db 100644 --- a/tests/brevitas_examples/test_quantize_model.py +++ b/tests/brevitas_examples/test_quantize_model.py @@ -14,6 +14,8 @@ from brevitas.nn import QuantReLU from brevitas.quant_tensor import QuantTensor from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model +from tests.marker import jit_disabled_for_local_loss +from tests.marker import jit_disabled_for_mock # CONSTANTS IMAGE_DIM = 16 @@ -568,6 +570,7 @@ def test_layerwise_percentile_for_calibration(simple_model, act_quant_percentile @pytest.mark.parametrize("quant_granularity", ["per_tensor", "per_channel"]) +@jit_disabled_for_local_loss() def test_layerwise_param_method_mse(simple_model, quant_granularity): """ We test layerwise quantization, with the weight and activation quantization `mse` parameter diff --git a/tests/marker.py b/tests/marker.py index f11dc7a4a..d4ae6d325 100644 --- a/tests/marker.py +++ b/tests/marker.py @@ -50,6 +50,16 @@ def skip_wrapper(f): return skip_wrapper +def jit_disabled_for_local_loss(): + skip = config.JIT_ENABLED + + def skip_wrapper(f): + return pytest.mark.skipif( + skip, reason=f'Local loss functions (e.g., MSE) require JIT to be disabled')(f) + + return skip_wrapper + + def jit_disabled_for_dynamic_quant_act(): skip = config.JIT_ENABLED