diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py
index fa6fa9004..f40b36035 100644
--- a/src/brevitas/core/stats/stats_op.py
+++ b/src/brevitas/core/stats/stats_op.py
@@ -572,6 +572,7 @@ def __init__(
             self,
             proxy_module,
             hqo_init_op_scale,
+            keepdim: bool,
             inner_stats_input_view_shape_impl: torch.nn.Module,
             scaling_min_val: Optional[float] = None,
             stats_reduce_dim: Optional[int] = None,
@@ -601,11 +602,11 @@ def __init__(
             self.clamp_min_ste = ScalarClampMinSte(scaling_min_val)
         else:
             self.clamp_min_ste = Identity()
+        self.keepdim = keepdim

     def parameter_search(self, xl, x):
         best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype)
         candidate = xl
-        # candidate = self.input_view_shape_impl(candidate)
         best_candidate = candidate
         beta = self.beta
         with torch.no_grad():
@@ -614,24 +615,23 @@ def parameter_search(self, xl, x):
                 self.set_local_loss_mode(True)
                 quant_tensor = self.proxy_forward(x).detach()
                 self.set_local_loss_mode(False)
-                loss = torch.abs(quant_tensor.value_ - x).mean()
+                loss = torch.abs(quant_tensor.value - x).mean()
                 best_candidate = torch.where(loss < best_loss, candidate, best_candidate)
                 if loss >= best_loss:
                     break
                 best_loss = torch.min(loss, best_loss)
-                W_e = shrink_lp_op(x - quant_tensor.value_, beta, self.lp_norm)
+                W_e = shrink_lp_op(x - quant_tensor.value, beta, self.lp_norm)
                 zero_point = quant_tensor.zero_point
                 num = self.input_view_shape_impl(x - W_e).detach()
                 den = self.input_view_shape_impl(
-                    torch.round(quant_tensor.value_ / quant_tensor.scale_) - zero_point).detach()
+                    torch.round(quant_tensor.value / quant_tensor.scale) - zero_point).detach()
                 mask = (num != 0.) & (den != 0.)
                 if self.stats_reduce_dim is None:
                     candidate = masked_median(num / den, mask)
                 else:
                     candidate = masked_median(
-                        num / den, mask, dim=self.stats_reduce_dim, keepdim=True)
-                # candidate = self.input_view_shape_impl(candidate)
+                        num / den, mask, dim=self.stats_reduce_dim, keepdim=self.keepdim)
                 candidate = self.clamp_min_ste(candidate)
                 bit_width = self.msb_clamp_bit_width_impl()
                 int_threshold = self.int_scaling_impl(bit_width)
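Reviewer note on the loop above: each iteration quantizes with the current candidate scale, shrinks the residual with an l_p proximal step, and re-estimates the scale from the masked median of `num / den`. A minimal sketch of a generalized shrinkage operator of the kind `shrink_lp_op` implements (illustrative only; Brevitas ships its own implementation, and the exact formula here is an assumption based on the HQQ-style formulation):

```python
import torch

def shrink_lp_sketch(x: torch.Tensor, beta: float, p: float) -> torch.Tensor:
    # Proximal ("shrinkage") step for an l_p penalty: pull each entry of the
    # residual toward zero; for p < 1 the threshold is data-dependent.
    if p == 1.0:
        return torch.sign(x) * torch.relu(torch.abs(x) - 1.0 / beta)
    return torch.sign(x) * torch.relu(torch.abs(x) - torch.abs(x).pow(p - 1) / beta)
```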
@@ -672,11 +672,10 @@ class HalfQuadraticOptimizerZeroPoint(torch.nn.Module):
     def __init__(
             self,
             proxy_module,
+            keepdim: bool,
             hqo_init_op_zp: torch.nn.Module,
             inner_stats_input_view_shape_impl: torch.nn.Module,
             stats_reduce_dim: Optional[int] = None,
-            inner_expanded_zero_point_shape=None,
-            reshaped_zero_point_shape=None,
             hqo_beta_zp: float = 1e0,
             hqo_kappa_zp: float = 1.01,
             hqo_lp_norm_zp: float = .5,
@@ -684,7 +683,6 @@ def __init__(
         super(HalfQuadraticOptimizerZeroPoint, self).__init__()
         self.hqo_init_op_zp = hqo_init_op_zp
         self.input_view_shape_impl = inner_stats_input_view_shape_impl
-        self.proxy_module = proxy_module
         self.proxy_forward = proxy_module.forward
         self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
         self.set_quantize_zero_point = lambda enabled: _set_quantize_zero_point(
@@ -696,38 +694,34 @@ def __init__(
         self.kappa = hqo_kappa_zp
         self.lp_norm = hqo_lp_norm_zp
         self.hqo_iters = hqo_iters_zp
-        self.inner_expanded_zero_point_shape = inner_expanded_zero_point_shape
-        self.reshaped_zero_point_shape = reshaped_zero_point_shape
+        self.keepdim = keepdim

     def parameter_search(self, xl, x):
         best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype)
         candidate = xl
-        candidate = self.input_view_shape_impl(candidate)
         best_candidate = candidate
         with torch.no_grad():
             for i in range(0, self.hqo_iters):
                 self.internal_candidate = candidate
                 self.set_local_loss_mode(True)
-                prev_state = _set_quantize_zero_point(self.proxy_module, False)
                 quant_tensor = self.proxy_forward(x).detach()
                 self.set_local_loss_mode(False)
-                _restore_quantize_zero_point(self.proxy_module, prev_state)
-                loss = torch.abs(quant_tensor.value - x).mean()
+                qt_value = self.input_view_shape_impl(quant_tensor.value)
+                qt_scale = self.input_view_shape_impl(quant_tensor.scale)
+                qt_int = self.input_view_shape_impl(quant_tensor.int())
+                loss = torch.abs(qt_value - x).mean()
                 best_candidate = torch.where(loss < best_loss, candidate, best_candidate)
                 if loss >= best_loss:
                     break
                 best_loss = torch.min(loss, best_loss)
-                W_e = shrink_lp_op(x - quant_tensor.value, self.beta, self.lp_norm)
+                W_e = shrink_lp_op(x - qt_value, self.beta, self.lp_norm)
+
+                val = self.input_view_shape_impl((x - W_e) - qt_int * qt_scale)
-                val = self.input_view_shape_impl((x - W_e) -
-                                                 quant_tensor.int() * quant_tensor.scale)
-                if self.inner_expanded_zero_point_shape is not None:
-                    val = val.reshape(self.inner_expanded_zero_point_shape)
                 if self.stats_reduce_dim is None:
                     candidate = torch.mean(val)
                 else:
-                    candidate = torch.mean(val, dim=self.stats_reduce_dim, keepdim=True)
-                candidate = self.input_view_shape_impl(candidate)
+                    candidate = torch.mean(val, dim=self.stats_reduce_dim, keepdim=self.keepdim)
                 self.beta *= self.kappa
         return best_candidate

@@ -735,8 +729,6 @@ def optimize(self, x):
         x_view = self.input_view_shape_impl(x)

         init = self.hqo_init_op_zp(x_view).detach()
-        if self.reshaped_zero_point_shape is not None:
-            x = x.reshape(self.reshaped_zero_point_shape)

         best_candidate = self.parameter_search(init, x)
diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py
index 1373bae2d..6f08a1bed 100644
--- a/src/brevitas/quant/base.py
+++ b/src/brevitas/quant/base.py
@@ -521,42 +521,23 @@ class MSEActZeroPoint(MSEZeroPoint):


 class HQOZeroPoint(ExtendedInjector):

     hqo_init_op_zp = NegativeMinOrZero
+    inner_stats_input_view_shape_impl = this.zero_point_stats_input_view_shape_impl
     stats_impl_zp = HalfQuadraticOptimizerZeroPoint
-    zero_point_stats_input_view_shape_impl = nn.Identity()

     @value
     def zero_point_stats_impl():
         return this.stats_impl_zp

-    @value
-    def inner_stats_input_view_shape_impl(scaling_per_output):
-        if scaling_per_output == ScalingPerOutputType.CHANNEL:
-            return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
-        elif scaling_per_output == ScalingPerOutputType.TENSOR:
-            return StatsInputViewShapeImpl.OVER_TENSOR
-        elif scaling_per_output == ScalingPerOutputType.GROUP:
-            return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK
-

 class HQOScale(ExtendedInjector):
     scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
-    scaling_stats_input_view_shape_impl = nn.Identity()
-
+    inner_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
     stats_impl_scale = HalfQuadraticOptimizerScale

     @value
     def scaling_stats_impl():
         return this.stats_impl_scale

-    @value
-    def inner_stats_input_view_shape_impl(scaling_per_output):
-        if scaling_per_output == ScalingPerOutputType.CHANNEL:
-            return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
-        elif scaling_per_output == ScalingPerOutputType.TENSOR:
-            return StatsInputViewShapeImpl.OVER_TENSOR
-        elif scaling_per_output == ScalingPerOutputType.GROUP:
-            return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK
-

 class HQOAsymmetricScale(HQOScale):
     hqo_init_op_scale = AbsMinMax
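The `this.`-based bindings above replace the duplicated `@value` switch: the inner HQO view now simply aliases whatever stats view the enclosing quantizer already resolves. A sketch of how a concrete quantizer might compose with `HQOScale` (the class name here is hypothetical; the real `Int8WeightPerChannelFloatHQO` referenced elsewhere in this PR is presumably defined along these lines):

```python
from brevitas.quant.base import HQOScale
from brevitas.quant.scaled_int import Int8WeightPerChannelFloat

# Dependency injection resolves this.scaling_stats_input_view_shape_impl against
# the composed class, so the HQO inner loop reuses the per-channel stats view
# that Int8WeightPerChannelFloat already defines.
class Int8WeightPerChannelFloatHQOSketch(HQOScale, Int8WeightPerChannelFloat):
    pass
```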
diff --git a/src/brevitas/quant/shifted_scaled_int.py b/src/brevitas/quant/shifted_scaled_int.py
index 900471c98..c8e75312b 100644
--- a/src/brevitas/quant/shifted_scaled_int.py
+++ b/src/brevitas/quant/shifted_scaled_int.py
@@ -5,6 +5,7 @@
 from brevitas.quant.base import HQOActZeroPoint
 from brevitas.quant.base import HQOAsymmetricScale
 from brevitas.quant.base import HQOZeroPoint
+from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
 from brevitas.quant.solver.act import ActQuantSolver
 from brevitas.quant.solver.bias import BiasQuantSolver
 from brevitas.quant.solver.trunc import TruncQuantSolver
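Independent of the new import, a usage sketch of the asymmetric weight quantizers this file exposes (standard Brevitas layer API; the layer shapes are illustrative):

```python
from brevitas.nn import QuantLinear
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat

# Attach the shifted (asymmetric) uint8 weight quantizer to a linear layer.
layer = QuantLinear(128, 64, bias=False, weight_quant=ShiftedUint8WeightPerChannelFloat)
qw = layer.quant_weight()  # QuantTensor carrying per-channel scale and zero-point
```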
diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
index 29f39be6f..a958b31fa 100644
--- a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
+++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
@@ -3,6 +3,7 @@

 import torch

+from brevitas.function.ops_ste import round_ste
 from brevitas.quant_tensor import _unpack_quant_tensor
 from brevitas.quant_tensor.base_quant_tensor import GroupwisIntQuantTensorBase
 from brevitas.quant_tensor.base_quant_tensor import QuantTensor
@@ -101,29 +102,29 @@ def zero_point(self):
         new_value, new_scale, new_zp = self.expand()
         return new_zp

-    @property
-    def _pre_round_float_value(self):
-        value, scale, zp = self.expand()
-        if self.scale.dtype == torch.bfloat16:
-            value = value.type(torch.float32)
-            scale = scale.type(torch.float32)
-        minifloat_value = value / scale
-        fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
-        int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale)
-        minifloat_value = minifloat_value / int_scale
-        return minifloat_value
-
     @property
     def is_valid(self):
         with torch.no_grad():
-            pre_round_minifloat_value = self._pre_round_float_value
-            rounded_minifloat_value = torch.round(pre_round_minifloat_value)
-            max_abs_diff = torch.max(torch.abs(pre_round_minifloat_value - rounded_minifloat_value))
+            pre_round_int_value = self._pre_round_int_value
+            rounded_int_value = torch.round(pre_round_int_value)
+            max_abs_diff = torch.max(torch.abs(pre_round_int_value - rounded_int_value))
             atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL
-            is_minifloat = max_abs_diff < atol
-            # We are missing the checks about self being contained between max and min value
-            # given by mantissa, exponent, inf, nan, and saturating
-            return is_minifloat
+            is_int = max_abs_diff < atol
+            if self.bit_width >= 2:
+                if self.signed:
+                    is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all()
+                    is_lower_b = (-2.0 ** (self.bit_width - 1) <= rounded_int_value).all()
+                else:
+                    is_upper_b = (2.0 ** self.bit_width - 1 >= rounded_int_value).all()
+                    is_lower_b = (0. <= rounded_int_value).all()
+                return (is_int & is_upper_b & is_lower_b).item()
+            else:  # binary case
+                unique_vals = rounded_int_value.unique(
+                    sorted=False, return_counts=False, return_inverse=False)
+                is_binary = unique_vals.view(-1).size()[0] == 2
+                is_signed = (unique_vals < 0.).any().item()
+                sign_match = is_signed == self.signed
+                return is_int.item() and is_binary and sign_match

     @property
     def device(self):
@@ -139,17 +140,38 @@ def device(self):
             raise RuntimeError("Value and metadata are on different devices")
         return value_device

-    def minifloat(self, float_datatype=True):
-        # TODO: Check if OCP and cast to proper data-type if matching
-        assert float_datatype, "Minifloat quant returns only higher precision dtype"
-
+    @property
+    def _pre_round_int_value(self):
+        value = self.value
+        scale = self.scale
+        zero_point = self.zero_point
+        if self.scale.dtype == torch.bfloat16:
+            value = self.value.type(torch.float32)
+            scale = self.scale.type(torch.float32)
+            zero_point = self.zero_point.type(torch.float32)
+        int_value = value / scale
+        int_value = int_value + zero_point
+        return int_value
+
+    def int(self, float_datatype=False):
         if self.is_valid:
-            fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
-            int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale)
-            float_value = torch.round(self._pre_round_float_value) * int_scale
-            return float_value.type(self.scale.dtype)
+            int_value = round_ste(self._pre_round_int_value)
+            if float_datatype:
+                # Values at 8bit and lower can be represented exactly with float16 and bfloat16
+                # otherwise (e.g. Int16 bias), we upscale to float32
+                if self.bit_width <= 8.:
+                    return int_value.type(self.scale.dtype)
+                else:
+                    return int_value.type(torch.float32)
+            else:
+                if self.bit_width <= 8. and self.signed_t.item():
+                    return int_value.to(torch.int8)
+                elif self.bit_width <= 8. and not self.signed_t.item():
+                    return int_value.to(torch.uint8)
+                else:
+                    return int_value.to(torch.int32)
         else:
-            raise RuntimeError(f"FloatQuantTensor not valid.")
+            raise RuntimeError(f"IntQuantTensor not valid.")

     @staticmethod
     def check_input_type(tensor):
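Standalone restatement of the validity math introduced above: the integer image of a valid tensor is `value / scale + zero_point`, and for bit width b the rounded values must fall in [-2^(b-1), 2^(b-1) - 1] when signed, or [0, 2^b - 1] when unsigned. A self-contained sketch in plain PyTorch that mirrors (but does not call) the QuantTensor code:

```python
import torch

def int_repr_in_range(value, scale, zero_point, bit_width: int, signed: bool) -> bool:
    # Recover the integer representation and check it against the b-bit bounds.
    int_value = torch.round(value / scale + zero_point)
    if signed:
        lo, hi = -2.0 ** (bit_width - 1), 2.0 ** (bit_width - 1) - 1
    else:
        lo, hi = 0.0, 2.0 ** bit_width - 1
    return bool(((int_value >= lo) & (int_value <= hi)).all())
```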
diff --git a/src/brevitas_examples/common/generative/quantizers.py b/src/brevitas_examples/common/generative/quantizers.py
index e26aafba8..e9963c239 100644
--- a/src/brevitas_examples/common/generative/quantizers.py
+++ b/src/brevitas_examples/common/generative/quantizers.py
@@ -12,6 +12,7 @@
 from brevitas.core.stats import NegativeMinOrZero
 from brevitas.core.stats.stats_op import HalfQuadraticOptimizerZeroPoint
 from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE
+from brevitas.core.zero_point import StatsFromParameterZeroPoint
 from brevitas.inject import ExtendedInjector
 from brevitas.inject import this
 from brevitas.inject import value
@@ -22,11 +23,13 @@
 from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
 from brevitas.proxy.groupwise_int_runtime_quant import GroupwiseActQuantProxyFromInjector
 from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector
+from brevitas.quant.base import HQOWeightZeroPoint
 from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
 from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
 from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
 from brevitas.quant.scaled_int import Int8ActPerTensorFloat
 from brevitas.quant.scaled_int import Int8WeightPerChannelFloat
+from brevitas.quant.scaled_int import Int8WeightPerChannelFloatHQO
 from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat
@@ -63,34 +66,14 @@ class ShiftedUintWeightAsymmetricGroupQuant(ShiftedUint8WeightPerChannelFloat):
     scaling_per_output_type = ScalingPerOutputType.GROUP


-from brevitas.quant.scaled_int import Int8WeightPerChannelFloatHQO
-
-
-class ShiftedUintWeightAsymmetricGroupQuantHQO(Int8WeightPerChannelFloatHQO):
+class ShiftedUintWeightAsymmetricGroupQuantHQO(HQOWeightZeroPoint,
+                                               ShiftedUint8WeightPerChannelFloat):
     """
     Block / group / vector signed asymmetric weight quantizer with float scales and zero-points.
     """
     proxy_class = GroupwiseWeightQuantProxyFromInjector
     scaling_per_output_type = ScalingPerOutputType.GROUP
-    # zero_point_input_shape = this.scaling_input_shape
-    # reshaped_zero_point_shape = this.reshaped_scaling_shape
-    # zero_point_shape = this.scaling_shape
-    # # inner_expanded_zero_point_shape = this.expanded_scaling_shape
-    # # expanded_zero_point_shape = this.expanded_scaling_shape
-    # zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
-    # zero_point_stats_input_concat_dim = 0
-    # # zero_point_impl = ExpandReshapeZeroPointWrapper
-    # zero_point_stats_impl = HalfQuadraticOptimizerZeroPoint
-    # hqo_init_op_zp = NegativeMinOrZero
-    # scaling_stats_impl = AbsMinMax
-    # keepdim = True
-    # # zero-point is converted to a parameter right away
-    # zero_point_impl = ParameterFromStatsFromParameterZeroPoint
-    # quantize_zero_point = False
-    # signed = False
-    # inner_stats_input_view_shape_impl = torch.nn.Identity()
-

 class Int8DynamicActPerTensorFloat(DynamicActProxyMixin, Int8ActPerTensorFloat):
     """
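The two-base rewrite above is the key simplification: `HQOWeightZeroPoint` only overrides the zero-point solver, so the MRO keeps `ShiftedUint8WeightPerChannelFloat`'s scale and bit-width bindings intact. The same pattern should yield a per-channel (non-group) variant, sketched here with a hypothetical class name:

```python
from brevitas.quant.base import HQOWeightZeroPoint
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat

# Hypothetical composition: per-channel asymmetric weights with an HQO-optimized
# zero-point; no proxy_class / scaling_per_output_type override is needed
# outside the groupwise case.
class ShiftedUint8WeightPerChannelFloatHQOSketch(HQOWeightZeroPoint,
                                                 ShiftedUint8WeightPerChannelFloat):
    pass
```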
""" proxy_class = GroupwiseWeightQuantProxyFromInjector scaling_per_output_type = ScalingPerOutputType.GROUP - # zero_point_input_shape = this.scaling_input_shape - # reshaped_zero_point_shape = this.reshaped_scaling_shape - # zero_point_shape = this.scaling_shape - # # inner_expanded_zero_point_shape = this.expanded_scaling_shape - # # expanded_zero_point_shape = this.expanded_scaling_shape - # zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl - # zero_point_stats_input_concat_dim = 0 - # # zero_point_impl = ExpandReshapeZeroPointWrapper - # zero_point_stats_impl = HalfQuadraticOptimizerZeroPoint - # hqo_init_op_zp = NegativeMinOrZero - # scaling_stats_impl = AbsMinMax - # keepdim = True - # # zero-point is converted to a parameter right away - # zero_point_impl = ParameterFromStatsFromParameterZeroPoint - # quantize_zero_point = False - # signed = False - # inner_stats_input_view_shape_impl = torch.nn.Identity() - class Int8DynamicActPerTensorFloat(DynamicActProxyMixin, Int8ActPerTensorFloat): """ diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index d2a6b3f22..d678b016d 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -7,9 +7,9 @@ import re import numpy as np -# from optimum.amd.brevitas.accelerate_utils import offload_model -# from optimum.amd.brevitas.accelerate_utils import remove_hooks -# from optimum.exporters.onnx import onnx_export_from_model +from optimum.amd.brevitas.accelerate_utils import offload_model +from optimum.amd.brevitas.accelerate_utils import remove_hooks +from optimum.exporters.onnx import onnx_export_from_model import torch from transformers import AutoModelForCausalLM from transformers import AutoTokenizer @@ -302,17 +302,17 @@ def main(): if args.weight_equalization: print("Apply weight equalization...") # In case of float16 model, we need to offload to account for missing ops - # model = offload_model(model) + model = offload_model(model) apply_weight_equalization(model) - # remove_hooks(model) + remove_hooks(model) print("Weight equalization applied.") if args.act_equalization is not None: - # offload_model(model) + offload_model(model) print("Apply act equalization (SmoothQuant)...") apply_act_equalization(model, args.act_equalization, calibration_loader) print("Act equalization applied.") - # remove_hooks(model) + remove_hooks(model) if not args.no_quantize: print("Applying model quantization...") @@ -339,7 +339,6 @@ def main(): quantize_embedding=args.quantize_embedding) # Tie back first/last layer weights in case they got untied print("Model quantization applied.") - # If any equalization has taken places, the embedding layer and the fully connected one are # not tied anymore, and they need to be treated as standalone, separate layers. # In all other cases we can tie them back so to preserve memory.