diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index 328ad63b3..eb250cc0d 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -8,6 +8,7 @@ from torch.nn import Module import brevitas +from brevitas.core.function_wrapper.misc import Identity from brevitas.core.quant.delay import DelayWrapper from brevitas.core.utils import StatelessBuffer from brevitas.function.ops_ste import round_ste @@ -138,7 +139,8 @@ def __init__( scaling_impl: Module, int_scaling_impl: Module, zero_point_impl: Module, - bit_width_impl: Module): + bit_width_impl: Module, + scaling_int_quant: Optional[Module] = None): super(RescalingIntQuant, self).__init__() self.int_quant = int_quant self.scaling_impl = scaling_impl diff --git a/src/brevitas/core/restrict_val.py b/src/brevitas/core/restrict_val.py index 7d6d83231..bc77134aa 100644 --- a/src/brevitas/core/restrict_val.py +++ b/src/brevitas/core/restrict_val.py @@ -170,3 +170,30 @@ def forward(self, x: Tensor): x = self.float_to_int_impl(x) x = self.power_of_two(x) return x + + +class QuantRestrictValue(brevitas.jit.ScriptModule): + + def __init__(self, restrict_value_float_to_int_impl: Module): + super(QuantRestrictValue, self).__init__() + self.float_to_int_impl = restrict_value_float_to_int_impl + + def restrict_init_float(self, x: float): + return Identity() + + def restrict_init_tensor(self, x: torch.Tensor): + return Identity() + + def restrict_init_module(self): + return Identity() + + def restrict_init_inplace_module(self): + return Identity() + + def retrocompatibility_op(self, x): + return Identity() + + @brevitas.jit.script_method + def forward(self, x: torch.Tensor): + o, *_ = self.float_to_int_impl(x) + return o diff --git a/src/brevitas/core/zero_point.py b/src/brevitas/core/zero_point.py index 796940f4f..499de376f 100644 --- a/src/brevitas/core/zero_point.py +++ b/src/brevitas/core/zero_point.py @@ -60,6 +60,20 @@ def forward(self, zero_point: Tensor, scale: Tensor, bit_width: Tensor) -> Tenso return out +class _ScaleShiftQuantZeroPoint(brevitas.jit.ScriptModule): + __constants__ = ['quantize_zero_point'] + + def __init__(self, zp_int_quant: Module, quantize_zero_point: bool) -> None: + super(_ScaleShiftQuantZeroPoint, self).__init__() + self.zp_int_quant = zp_int_quant + self.quantize_zero_point = quantize_zero_point + + @brevitas.jit.script_method + def forward(self, zero_point: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor: + quant_zp, scale, *_ = self.zp_int_quant(zero_point) + return quant_zp + + class StatsFromParameterZeroPoint(brevitas.jit.ScriptModule): def __init__( @@ -70,7 +84,8 @@ def __init__( zero_point_stats_input_concat_dim: int, zero_point_stats_impl: Module, zero_point_shape: Tuple[int, ...], - tracked_parameter_list: List[torch.nn.Parameter]) -> None: + tracked_parameter_list: List[torch.nn.Parameter], + scale_shit_zero_point_impl: Optional[Module] = None) -> None: super(StatsFromParameterZeroPoint, self).__init__() self.parameter_list_stats = _ParameterListStats( zero_point_stats_impl, @@ -78,7 +93,10 @@ def __init__( zero_point_stats_input_view_shape_impl, zero_point_stats_input_concat_dim, tracked_parameter_list) - self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point) + if scale_shit_zero_point_impl is None: + self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point) + else: + self.scale_shift_zero_point = scale_shit_zero_point_impl @brevitas.jit.script_method def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> torch.Tensor: diff --git a/tests/brevitas/core/test_scaling_quant.py b/tests/brevitas/core/test_scaling_quant.py new file mode 100644 index 000000000..dab8312e9 --- /dev/null +++ b/tests/brevitas/core/test_scaling_quant.py @@ -0,0 +1,127 @@ +from dependencies import this +from dependencies import value +import torch + +from brevitas.core.quant.int import RescalingIntQuant +from brevitas.core.restrict_val import QuantRestrictValue +from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE +from brevitas.inject.enum import ScalingPerOutputType +import brevitas.nn as qnn +from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector +from brevitas.quant.scaled_int import Int8WeightPerTensorFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat + + +class QuantScalingInt(Int8WeightPerTensorFloat): + bit_width = 8 + module = (this << 1).module + tracked_parameter_list = (this << 1).tracked_parameter_list + upstream_scaling = (this << 1).scaling_per_output_type + rescaling_int_quant = RescalingIntQuant + + @value + def scaling_shape( + scaling_per_output, + scaling_per_output_channel_shape, + expanded_groupwise_shape, + group_dim, + upstream_scaling): + if scaling_per_output == ScalingPerOutputType.TENSOR: + scaling = SCALAR_SHAPE + elif scaling_per_output == ScalingPerOutputType.CHANNEL: + scaling = scaling_per_output_channel_shape + elif scaling_per_output == ScalingPerOutputType.GROUP: + # Scaling shape is like expanded_groupwise_shape but has 1 in position group_dim + 1 + assert expanded_groupwise_shape is not None, "Per Group scaling not correctly configured" + assert group_dim is not None, "Per Group scaling not correctly configured" + size = list(expanded_groupwise_shape) + size[group_dim + 1] = 1 + scaling = tuple(size) + + # When quantizing scale of groupwise, there will be one extra dim compared to the normal case + if upstream_scaling == ScalingPerOutputType.GROUP: + scaling = list(scaling) + scaling.insert(-1, 1) + scaling = tuple(scaling) + return scaling + + +from brevitas.core.zero_point import _ScaleShiftQuantZeroPoint + + +class QuantZPInt(Int8WeightPerTensorFloat): + bit_width = 8 + module = (this << 1).module + tracked_parameter_list = (this << 1).tracked_parameter_list + upstream_scaling = (this << 1).scaling_per_output_type + rescaling_int_quant = RescalingIntQuant + bit_width = 6 + quantize_zero_point = True + scaling_per_output_type = ScalingPerOutputType.CHANNEL + + @value + def scaling_shape( + scaling_per_output, + scaling_per_output_channel_shape, + expanded_groupwise_shape, + group_dim, + upstream_scaling): + if scaling_per_output == ScalingPerOutputType.TENSOR: + scaling = SCALAR_SHAPE + elif scaling_per_output == ScalingPerOutputType.CHANNEL: + scaling = scaling_per_output_channel_shape + elif scaling_per_output == ScalingPerOutputType.GROUP: + # Scaling shape is like expanded_groupwise_shape but has 1 in position group_dim + 1 + assert expanded_groupwise_shape is not None, "Per Group scaling not correctly configured" + assert group_dim is not None, "Per Group scaling not correctly configured" + size = list(expanded_groupwise_shape) + size[group_dim + 1] = 1 + scaling = tuple(size) + + # When quantizing scale of groupwise, there will be one extra dim compared to the normal case + if upstream_scaling == ScalingPerOutputType.GROUP: + scaling = list(scaling) + scaling.insert(-1, 1) + scaling = tuple(scaling) + return scaling + + +class QuantScaleInt8WeightPerTensorFloat(ShiftedUint8WeightPerTensorFloat): + proxy_class = GroupwiseWeightQuantProxyFromInjector + scaling_int_quant = QuantScalingInt + zp_int = QuantZPInt + restrict_scaling_impl = QuantRestrictValue + scaling_per_output_type = ScalingPerOutputType.GROUP + scale_shit_zero_point_impl = _ScaleShiftQuantZeroPoint + group_size = 32 + + @value + def restrict_value_float_to_int_impl(): + return this.scaling_int_quant.rescaling_int_quant + + @value + def zp_int_quant(): + return this.zp_int.rescaling_int_quant + + +def test_quant_scale(): + + def hook_scale(module, inp): + inp = inp[0] + quant_scale, scale, *_ = module.float_to_int_impl(inp) + assert torch.allclose(quant_scale / scale, torch.round(quant_scale / scale)) + + def hook_zp(module, inp): + inp = inp[0] + quant_scale, scale, *_ = module.zp_int_quant(inp) + assert torch.allclose(quant_scale / scale, torch.round(quant_scale / scale)) + + linear = qnn.QuantLinear(64, 768, weight_quant=QuantScaleInt8WeightPerTensorFloat) + for module in linear.modules(): + if isinstance(module, QuantRestrictValue): + module.register_forward_pre_hook(hook_scale) + for module in linear.modules(): + if isinstance(module, _ScaleShiftQuantZeroPoint): + module.register_forward_pre_hook(hook_zp) + + linear(torch.randn(1, 64))