Feat (core): quantized scale/zero_point (#1038)
Giuseppe5 authored Dec 3, 2024
1 parent 5e94dae commit 6ca9e31
Showing 3 changed files with 178 additions and 2 deletions.
27 changes: 27 additions & 0 deletions src/brevitas/core/restrict_val.py
@@ -170,3 +170,30 @@ def forward(self, x: Tensor):
        x = self.float_to_int_impl(x)
        x = self.power_of_two(x)
        return x


class QuantRestrictValue(brevitas.jit.ScriptModule):

    def __init__(self, restrict_value_float_to_int_impl: Module):
        super(QuantRestrictValue, self).__init__()
        self.float_to_int_impl = restrict_value_float_to_int_impl

    def restrict_init_float(self, x: float):
        return Identity()

    def restrict_init_tensor(self, x: torch.Tensor):
        return Identity()

    def restrict_init_module(self):
        return Identity()

    def restrict_init_inplace_module(self):
        return Identity()

    def retrocompatibility_op(self, x):
        return Identity()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor):
        o, *_ = self.float_to_int_impl(x)
        return o
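
For context, a minimal usage sketch of QuantRestrictValue outside this diff, assuming default eager mode (BREVITAS_JIT unset) and a hypothetical stand-in quantizer that mimics the (value, scale, zero_point, bit_width) tuple returned by RescalingIntQuant:

import torch
from torch import nn

from brevitas.core.restrict_val import QuantRestrictValue


class DummyIntQuant(nn.Module):
    # Hypothetical stand-in: returns (value, scale, zero_point, bit_width) like RescalingIntQuant
    def forward(self, x):
        scale = torch.tensor(0.1)
        quant_x = torch.round(x / scale) * scale
        return quant_x, scale, torch.tensor(0.), torch.tensor(5.)


restrict = QuantRestrictValue(DummyIntQuant())
out = restrict(torch.randn(4))  # only the quantized value is kept; the rest of the tuple is dropped

The forward simply routes the input through the wrapped quantizer and keeps only the first element of its output tuple.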
22 changes: 20 additions & 2 deletions src/brevitas/core/zero_point.py
@@ -60,6 +60,18 @@ def forward(self, zero_point: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
        return out


class _ScaleShiftQuantZeroPoint(brevitas.jit.ScriptModule):

    def __init__(self, zp_int_quant: Module) -> None:
        super(_ScaleShiftQuantZeroPoint, self).__init__()
        self.zp_int_quant = zp_int_quant

    @brevitas.jit.script_method
    def forward(self, zero_point: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
        quant_zp, *_ = self.zp_int_quant(zero_point)
        return quant_zp


class StatsFromParameterZeroPoint(brevitas.jit.ScriptModule):

    def __init__(
@@ -70,15 +82,21 @@ def __init__(
            zero_point_stats_input_concat_dim: int,
            zero_point_stats_impl: Module,
            zero_point_shape: Tuple[int, ...],
-           tracked_parameter_list: List[torch.nn.Parameter]) -> None:
+           tracked_parameter_list: List[torch.nn.Parameter],
+           scale_shift_zero_point_impl: Optional[Module] = None) -> None:
        super(StatsFromParameterZeroPoint, self).__init__()
        self.parameter_list_stats = _ParameterListStats(
            zero_point_stats_impl,
            zero_point_shape,
            zero_point_stats_input_view_shape_impl,
            zero_point_stats_input_concat_dim,
            tracked_parameter_list)
-       self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)
+       # This is for backward compatibility. Having int_quant/quantize_zero_point required for this
+       # interface but not for the else branch seems a bit off and might require some clean-up.
+       if scale_shift_zero_point_impl is None:
+           self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)
+       else:
+           self.scale_shift_zero_point = scale_shift_zero_point_impl

    @brevitas.jit.script_method
    def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> torch.Tensor:
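
For context, a minimal sketch of the new zero-point path outside this diff: _ScaleShiftQuantZeroPoint quantizes the zero point through an injected quantizer and ignores the scale and bit_width arguments it receives. The stand-in quantizer below is hypothetical and only mimics the 4-tuple returned by RescalingIntQuant; default eager mode (BREVITAS_JIT unset) is assumed.

import torch
from torch import nn

from brevitas.core.zero_point import _ScaleShiftQuantZeroPoint


class DummyZPQuant(nn.Module):
    # Hypothetical stand-in: returns (value, scale, zero_point, bit_width) like RescalingIntQuant
    def forward(self, x):
        scale = torch.tensor(0.25)
        return torch.round(x / scale) * scale, scale, torch.tensor(0.), torch.tensor(6.)


ssqzp = _ScaleShiftQuantZeroPoint(zp_int_quant=DummyZPQuant())
quant_zp = ssqzp(torch.rand(8), scale=torch.tensor(1.), bit_width=torch.tensor(8.))
# quant_zp is the quantized zero point; scale and bit_width are accepted only for
# interface compatibility with _ScaleShiftZeroPoint and are unused here

StatsFromParameterZeroPoint falls back to the existing _ScaleShiftZeroPoint when no scale_shift_zero_point_impl is provided, so existing quantizers are unaffected.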
131 changes: 131 additions & 0 deletions tests/brevitas/core/test_scaling_quant.py
@@ -0,0 +1,131 @@
from dependencies import this
from dependencies import value
import torch

from brevitas.core.quant.int import RescalingIntQuant
from brevitas.core.restrict_val import QuantRestrictValue
from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE
from brevitas.inject.enum import ScalingPerOutputType
import brevitas.nn as qnn
from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat

ZP_BIT_WIDTH = 6
SCALE_BIT_WIDTH = 5


class QuantScalingInt(Int8WeightPerTensorFloat):
    bit_width = SCALE_BIT_WIDTH
    module = (this << 1).module
    tracked_parameter_list = (this << 1).tracked_parameter_list
    upstream_scaling = (this << 1).scaling_per_output_type
    rescaling_int_quant = RescalingIntQuant

    @value
    def scaling_shape(
            scaling_per_output,
            scaling_per_output_channel_shape,
            expanded_groupwise_shape,
            group_dim,
            upstream_scaling):
        if scaling_per_output == ScalingPerOutputType.TENSOR:
            scaling = SCALAR_SHAPE
        elif scaling_per_output == ScalingPerOutputType.CHANNEL:
            scaling = scaling_per_output_channel_shape
        elif scaling_per_output == ScalingPerOutputType.GROUP:
            # Scaling shape is like expanded_groupwise_shape but has 1 in position group_dim + 1
            assert expanded_groupwise_shape is not None, "Per Group scaling not correctly configured"
            assert group_dim is not None, "Per Group scaling not correctly configured"
            size = list(expanded_groupwise_shape)
            size[group_dim + 1] = 1
            scaling = tuple(size)

        # When quantizing scale of groupwise, there will be one extra dim compared to the normal case
        if upstream_scaling == ScalingPerOutputType.GROUP:
            scaling = list(scaling)
            scaling.insert(-1, 1)
            scaling = tuple(scaling)
        return scaling


from brevitas.core.zero_point import _ScaleShiftQuantZeroPoint


class QuantZPInt(Int8WeightPerTensorFloat):
    module = (this << 1).module
    tracked_parameter_list = (this << 1).tracked_parameter_list
    upstream_scaling = (this << 1).scaling_per_output_type
    rescaling_int_quant = RescalingIntQuant
    bit_width = ZP_BIT_WIDTH
    quantize_zero_point = True
    scaling_per_output_type = ScalingPerOutputType.CHANNEL

    @value
    def scaling_shape(
            scaling_per_output,
            scaling_per_output_channel_shape,
            expanded_groupwise_shape,
            group_dim,
            upstream_scaling):
        if scaling_per_output == ScalingPerOutputType.TENSOR:
            scaling = SCALAR_SHAPE
        elif scaling_per_output == ScalingPerOutputType.CHANNEL:
            scaling = scaling_per_output_channel_shape
        elif scaling_per_output == ScalingPerOutputType.GROUP:
            # Scaling shape is like expanded_groupwise_shape but has 1 in position group_dim + 1
            assert expanded_groupwise_shape is not None, "Per Group scaling not correctly configured"
            assert group_dim is not None, "Per Group scaling not correctly configured"
            size = list(expanded_groupwise_shape)
            size[group_dim + 1] = 1
            scaling = tuple(size)

        # When quantizing scale of groupwise, there will be one extra dim compared to the normal case
        if upstream_scaling == ScalingPerOutputType.GROUP:
            scaling = list(scaling)
            scaling.insert(-1, 1)
            scaling = tuple(scaling)
        return scaling


class QuantScaleQuantZPInt8WeightPerTensorFloat(ShiftedUint8WeightPerTensorFloat):
    proxy_class = GroupwiseWeightQuantProxyFromInjector
    scaling_int_quant = QuantScalingInt
    zp_int = QuantZPInt
    restrict_scaling_impl = QuantRestrictValue
    scaling_per_output_type = ScalingPerOutputType.GROUP
    scale_shift_zero_point_impl = _ScaleShiftQuantZeroPoint
    group_size = 32

    @value
    def restrict_value_float_to_int_impl():
        return this.scaling_int_quant.rescaling_int_quant

    @value
    def zp_int_quant():
        return this.zp_int.rescaling_int_quant


def test_quant_scale():

    def hook_scale(module, inp):
        inp = inp[0]
        quant_scale, scale, zp, bit_width = module.float_to_int_impl(inp)
        assert bit_width == SCALE_BIT_WIDTH
        assert torch.allclose(quant_scale / scale, torch.round(quant_scale / scale))

    def hook_zp(module, inp):
        inp = inp[0]
        quant_scale, scale, zp, bit_width = module.zp_int_quant(inp)
        assert bit_width == ZP_BIT_WIDTH
        assert torch.allclose(quant_scale / scale, torch.round(quant_scale / scale))

    linear = qnn.QuantLinear(64, 768, weight_quant=QuantScaleQuantZPInt8WeightPerTensorFloat)
    for module in linear.modules():
        if isinstance(module, QuantRestrictValue):
            module.register_forward_pre_hook(hook_scale)
    for module in linear.modules():
        if isinstance(module, _ScaleShiftQuantZeroPoint):
            module.register_forward_pre_hook(hook_zp)

    linear(torch.randn(1, 64))

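For intuition, a worked example (outside the commit) of the GROUP branch of the scaling_shape helper defined in the test above, using hypothetical shapes: a (768, 64) weight with group_size = 32 grouped along dim 1 expands to (768, 2, 32).

# Hypothetical inputs for the GROUP branch of scaling_shape above
expanded_groupwise_shape = (768, 2, 32)  # (out_features, num_groups, group_size)
group_dim = 1

size = list(expanded_groupwise_shape)
size[group_dim + 1] = 1  # one value per group -> [768, 2, 1]
scaling = tuple(size)

# If the upstream weight quantizer is also groupwise, an extra broadcast dim is inserted
scaling = list(scaling)
scaling.insert(-1, 1)
scaling = tuple(scaling)

print(scaling)  # (768, 2, 1, 1)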