From 6ca9e311b9dec7a6561951965677c2f1ccd3fb3d Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 3 Dec 2024 14:32:29 +0000 Subject: [PATCH] Feat (core): quantized scale/zero_point (#1038) --- src/brevitas/core/restrict_val.py | 27 +++++ src/brevitas/core/zero_point.py | 22 +++- tests/brevitas/core/test_scaling_quant.py | 131 ++++++++++++++++++++++ 3 files changed, 178 insertions(+), 2 deletions(-) create mode 100644 tests/brevitas/core/test_scaling_quant.py diff --git a/src/brevitas/core/restrict_val.py b/src/brevitas/core/restrict_val.py index 7d6d83231..bc77134aa 100644 --- a/src/brevitas/core/restrict_val.py +++ b/src/brevitas/core/restrict_val.py @@ -170,3 +170,30 @@ def forward(self, x: Tensor): x = self.float_to_int_impl(x) x = self.power_of_two(x) return x + + +class QuantRestrictValue(brevitas.jit.ScriptModule): + + def __init__(self, restrict_value_float_to_int_impl: Module): + super(QuantRestrictValue, self).__init__() + self.float_to_int_impl = restrict_value_float_to_int_impl + + def restrict_init_float(self, x: float): + return Identity() + + def restrict_init_tensor(self, x: torch.Tensor): + return Identity() + + def restrict_init_module(self): + return Identity() + + def restrict_init_inplace_module(self): + return Identity() + + def retrocompatibility_op(self, x): + return Identity() + + @brevitas.jit.script_method + def forward(self, x: torch.Tensor): + o, *_ = self.float_to_int_impl(x) + return o diff --git a/src/brevitas/core/zero_point.py b/src/brevitas/core/zero_point.py index 796940f4f..f74fffae8 100644 --- a/src/brevitas/core/zero_point.py +++ b/src/brevitas/core/zero_point.py @@ -60,6 +60,18 @@ def forward(self, zero_point: Tensor, scale: Tensor, bit_width: Tensor) -> Tenso return out +class _ScaleShiftQuantZeroPoint(brevitas.jit.ScriptModule): + + def __init__(self, zp_int_quant: Module) -> None: + super(_ScaleShiftQuantZeroPoint, self).__init__() + self.zp_int_quant = zp_int_quant + + @brevitas.jit.script_method + def forward(self, zero_point: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor: + quant_zp, *_ = self.zp_int_quant(zero_point) + return quant_zp + + class StatsFromParameterZeroPoint(brevitas.jit.ScriptModule): def __init__( @@ -70,7 +82,8 @@ def __init__( zero_point_stats_input_concat_dim: int, zero_point_stats_impl: Module, zero_point_shape: Tuple[int, ...], - tracked_parameter_list: List[torch.nn.Parameter]) -> None: + tracked_parameter_list: List[torch.nn.Parameter], + scale_shift_zero_point_impl: Optional[Module] = None) -> None: super(StatsFromParameterZeroPoint, self).__init__() self.parameter_list_stats = _ParameterListStats( zero_point_stats_impl, @@ -78,7 +91,12 @@ def __init__( zero_point_stats_input_view_shape_impl, zero_point_stats_input_concat_dim, tracked_parameter_list) - self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point) + # This is for backward compatibility. Having int_quant/quantize_zero_point required for this + # interface but not for the else seems a bit off and might require some clean-up. 
+ if scale_shift_zero_point_impl is None: + self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point) + else: + self.scale_shift_zero_point = scale_shift_zero_point_impl @brevitas.jit.script_method def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> torch.Tensor: diff --git a/tests/brevitas/core/test_scaling_quant.py b/tests/brevitas/core/test_scaling_quant.py new file mode 100644 index 000000000..ba3d8ef7c --- /dev/null +++ b/tests/brevitas/core/test_scaling_quant.py @@ -0,0 +1,131 @@ +from dependencies import this +from dependencies import value +import torch + +from brevitas.core.quant.int import RescalingIntQuant +from brevitas.core.restrict_val import QuantRestrictValue +from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE +from brevitas.inject.enum import ScalingPerOutputType +import brevitas.nn as qnn +from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector +from brevitas.quant.scaled_int import Int8WeightPerTensorFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat + +ZP_BIT_WIDTH = 6 +SCALE_BIT_WIDTH = 5 + + +class QuantScalingInt(Int8WeightPerTensorFloat): + bit_width = SCALE_BIT_WIDTH + module = (this << 1).module + tracked_parameter_list = (this << 1).tracked_parameter_list + upstream_scaling = (this << 1).scaling_per_output_type + rescaling_int_quant = RescalingIntQuant + + @value + def scaling_shape( + scaling_per_output, + scaling_per_output_channel_shape, + expanded_groupwise_shape, + group_dim, + upstream_scaling): + if scaling_per_output == ScalingPerOutputType.TENSOR: + scaling = SCALAR_SHAPE + elif scaling_per_output == ScalingPerOutputType.CHANNEL: + scaling = scaling_per_output_channel_shape + elif scaling_per_output == ScalingPerOutputType.GROUP: + # Scaling shape is like expanded_groupwise_shape but has 1 in position group_dim + 1 + assert expanded_groupwise_shape is not None, "Per Group scaling not correctly configured" + assert group_dim is not None, "Per Group scaling not correctly configured" + size = list(expanded_groupwise_shape) + size[group_dim + 1] = 1 + scaling = tuple(size) + + # When quantizing scale of groupwise, there will be one extra dim compared to the normal case + if upstream_scaling == ScalingPerOutputType.GROUP: + scaling = list(scaling) + scaling.insert(-1, 1) + scaling = tuple(scaling) + return scaling + + +from brevitas.core.zero_point import _ScaleShiftQuantZeroPoint + + +class QuantZPInt(Int8WeightPerTensorFloat): + module = (this << 1).module + tracked_parameter_list = (this << 1).tracked_parameter_list + upstream_scaling = (this << 1).scaling_per_output_type + rescaling_int_quant = RescalingIntQuant + bit_width = ZP_BIT_WIDTH + quantize_zero_point = True + scaling_per_output_type = ScalingPerOutputType.CHANNEL + + @value + def scaling_shape( + scaling_per_output, + scaling_per_output_channel_shape, + expanded_groupwise_shape, + group_dim, + upstream_scaling): + if scaling_per_output == ScalingPerOutputType.TENSOR: + scaling = SCALAR_SHAPE + elif scaling_per_output == ScalingPerOutputType.CHANNEL: + scaling = scaling_per_output_channel_shape + elif scaling_per_output == ScalingPerOutputType.GROUP: + # Scaling shape is like expanded_groupwise_shape but has 1 in position group_dim + 1 + assert expanded_groupwise_shape is not None, "Per Group scaling not correctly configured" + assert group_dim is not None, "Per Group scaling not correctly configured" + size = list(expanded_groupwise_shape) + size[group_dim + 1] = 1 + scaling = 
tuple(size) + + # When quantizing scale of groupwise, there will be one extra dim compared to the normal case + if upstream_scaling == ScalingPerOutputType.GROUP: + scaling = list(scaling) + scaling.insert(-1, 1) + scaling = tuple(scaling) + return scaling + + +class QuantScaleQuantZPInt8WeightPerTensorFloat(ShiftedUint8WeightPerTensorFloat): + proxy_class = GroupwiseWeightQuantProxyFromInjector + scaling_int_quant = QuantScalingInt + zp_int = QuantZPInt + restrict_scaling_impl = QuantRestrictValue + scaling_per_output_type = ScalingPerOutputType.GROUP + scale_shift_zero_point_impl = _ScaleShiftQuantZeroPoint + group_size = 32 + + @value + def restrict_value_float_to_int_impl(): + return this.scaling_int_quant.rescaling_int_quant + + @value + def zp_int_quant(): + return this.zp_int.rescaling_int_quant + + +def test_quant_scale(): + + def hook_scale(module, inp): + inp = inp[0] + quant_scale, scale, zp, bit_width = module.float_to_int_impl(inp) + assert bit_width == SCALE_BIT_WIDTH + assert torch.allclose(quant_scale / scale, torch.round(quant_scale / scale)) + + def hook_zp(module, inp): + inp = inp[0] + quant_scale, scale, zp, bit_width = module.zp_int_quant(inp) + assert bit_width == ZP_BIT_WIDTH + assert torch.allclose(quant_scale / scale, torch.round(quant_scale / scale)) + + linear = qnn.QuantLinear(64, 768, weight_quant=QuantScaleQuantZPInt8WeightPerTensorFloat) + for module in linear.modules(): + if isinstance(module, QuantRestrictValue): + module.register_forward_pre_hook(hook_scale) + for module in linear.modules(): + if isinstance(module, _ScaleShiftQuantZeroPoint): + module.register_forward_pre_hook(hook_zp) + + linear(torch.randn(1, 64))
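
A minimal usage sketch of the quantized scale/zero-point path introduced by this patch, assuming QuantScaleQuantZPInt8WeightPerTensorFloat is defined as in the new test above (it is test-local, not an importable quantizer); this is illustrative only and not part of the patch body:

    import torch
    import brevitas.nn as qnn

    # Quantizer assumed to be defined as in tests/brevitas/core/test_scaling_quant.py:
    # groupwise weight quantization where the scale is itself quantized through
    # QuantRestrictValue (5-bit) and the zero-point through _ScaleShiftQuantZeroPoint (6-bit).
    linear = qnn.QuantLinear(64, 768, weight_quant=QuantScaleQuantZPInt8WeightPerTensorFloat)

    # A regular forward pass runs weight quantization with the quantized
    # scale/zero-point in the loop.
    out = linear(torch.randn(1, 64))

The key wiring is that restrict_scaling_impl = QuantRestrictValue receives a RescalingIntQuant as its restrict_value_float_to_int_impl, and scale_shift_zero_point_impl = _ScaleShiftQuantZeroPoint re-quantizes the zero-point, which is what the forward pre-hooks in the new test assert.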