From 83f58885be9e1384b705b4299e1a23c5fb371880 Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Thu, 21 Dec 2023 15:37:15 +0000 Subject: [PATCH] Feat (GPFA2Q): add A2Q bound to GPFQ --- src/brevitas/core/scaling/pre_scaling.py | 23 +--- src/brevitas/function/ops.py | 15 +++ src/brevitas/graph/gpfq.py | 142 ++++++++++++++++++++--- src/brevitas/graph/gptq.py | 9 +- src/brevitas/graph/gpxq.py | 14 ++- 5 files changed, 159 insertions(+), 44 deletions(-) diff --git a/src/brevitas/core/scaling/pre_scaling.py b/src/brevitas/core/scaling/pre_scaling.py index 632242507..dd125396d 100644 --- a/src/brevitas/core/scaling/pre_scaling.py +++ b/src/brevitas/core/scaling/pre_scaling.py @@ -14,6 +14,7 @@ from brevitas.core.stats import SCALAR_SHAPE from brevitas.core.stats.stats_wrapper import _Stats from brevitas.function import abs_binary_sign_grad +from brevitas.function import get_upper_bound_on_l1_norm __all__ = ["ParameterPreScalingWeightNorm", "AccumulatorAwareParameterPreScaling"] @@ -170,25 +171,6 @@ def __init__( ) self.accumulator_bit_width = accumulator_bit_width_impl - @brevitas.jit.script_method - def get_upper_bound_on_l1_norm(self, input_bit_width: Tensor, input_is_signed: bool) -> Tensor: - """Calculate the upper bound on the l1-norm of the weights using the derivations from - `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance` - by I.Colbert, A.Pappalardo, and J.Petri-Koenig.""" - assert input_bit_width is not None, "A2Q relies on input bit-width." - assert input_is_signed is not None, "A2Q relies on input sign." - input_is_signed = float(input_is_signed) # 1. if signed else 0. - # This is the minimum of the two maximum magnitudes that P could take, which are -2^{P-1} - # and 2^{P-1}-1. Note that evaluating to -2^{P-1} would mean there is a possibility of overflow - # on the positive side of this range. - max_accumulator_bit_width = self.accumulator_bit_width() # P - max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. # 2^{P-1}-1 - # This is the maximum possible magnitude that the input data could take. When the data is signed, - # this is 2^{N-1}. When the data is unsigned, this is 2^N - 1. We use a slightly looser bound here - # to simplify our derivations on the export validation. - max_input_mag_inverse = pow(2., input_is_signed - input_bit_width) - return max_accumulator_mag * max_input_mag_inverse - @brevitas.jit.script_method def forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor: """Takes weights as input and returns the pre-clipping scaling factor""" @@ -196,7 +178,8 @@ def forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: boo d_w = self.stats(weights) # denominator for weight normalization s = self.scaling_impl(weights) # s g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) # g - T = self.get_upper_bound_on_l1_norm(input_bit_width, input_is_signed) # T / s + T = get_upper_bound_on_l1_norm( + self.accumulator_bit_width(), input_bit_width, input_is_signed) # T / s g = torch.clamp_max(g / s, T) value = d_w / g # calculating final pre-clipping scaling factor # re-apply clamp_min_ste from restrict_scaling_impl to the specified pre_scaling_min_val diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index f68ae9ede..ec326602d 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -201,3 +201,18 @@ def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_b device=mantissa_bit_width.device))) max_val = max_mantissa * (2 ** max_exponent) return max_val + + +def get_upper_bound_on_l1_norm( + accumulator_bit_width: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor: + """Calculate the upper bound on the l1-norm of the weights using the derivations from + `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance` + by I.Colbert, A.Pappalardo, and J.Petri-Koenig.""" + assert input_bit_width is not None, "A2Q relies on input bit-width." + assert input_is_signed is not None, "A2Q relies on input sign." + assert accumulator_bit_width is not None, "A2Q relies on accumulator bit-width." + input_is_signed = float(input_is_signed) # 1. if signed else 0. + max_accumulator_bit_width = accumulator_bit_width # P + max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. # 2^{P-1}-1 + max_input_mag_inverse = pow(2., input_is_signed - input_bit_width) + return max_accumulator_mag * max_input_mag_inverse diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index 2d312a549..cad5d9043 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -8,6 +8,7 @@ import torch import unfoldNd +from brevitas.function import get_upper_bound_on_l1_norm from brevitas.graph.gpxq import GPxQ from brevitas.graph.gpxq import gpxq_mode from brevitas.graph.gpxq import StopFwdException @@ -45,7 +46,9 @@ def __init__( use_quant_activations: bool = True, p: float = 1.0, return_forward_output: bool = False, - act_order: bool = False) -> None: + act_order: bool = False, + use_gpfa2q: bool = False, + accumulator_bit_width: Optional[int] = None) -> None: if not inplace: model = deepcopy(model) super().__init__( @@ -61,6 +64,10 @@ def __init__( self.model.forward = self.catch_stopfwd self.p = p + # GPFA2Q params + self.use_gpfa2q = use_gpfa2q + self.accumulator_bit_width = accumulator_bit_width + def catch_stopfwd(self, *args, **kwargs): # Collect quant input try: @@ -96,13 +103,23 @@ def catch_stopfwd(self, *args, **kwargs): def initialize_module_optimizer( self, layer, name, act_order, len_parallel_layers, create_weight_orig): - return GPFQ( - layer=layer, - name=name, - act_order=act_order, - len_parallel_layers=len_parallel_layers, - create_weight_orig=create_weight_orig, - p=self.p) + if not self.use_gpfa2q: + return GPFQ( + layer=layer, + name=name, + act_order=act_order, + len_parallel_layers=len_parallel_layers, + create_weight_orig=create_weight_orig, + p=self.p) + else: + return GPFA2Q( + layer=layer, + name=name, + act_order=act_order, + len_parallel_layers=len_parallel_layers, + create_weight_orig=create_weight_orig, + p=self.p, + accumulator_bit_width=self.accumulator_bit_width) class GPFQ(GPxQ): @@ -110,14 +127,7 @@ class GPFQ(GPxQ): Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main """ - def __init__( - self, - layer, - name, - act_order, - len_parallel_layers=1, - create_weight_orig=True, - p=1.0) -> None: + def __init__(self, layer, name, act_order, len_parallel_layers, create_weight_orig, p) -> None: super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig) @@ -256,3 +266,103 @@ def single_layer_update(self): del self.float_input del self.quantized_input + + +class GPFA2Q(GPFQ): + + def __init__( + self, + layer, + name, + act_order, + len_parallel_layers, + create_weight_orig, + accumulator_bit_width, + p) -> None: + GPFQ.__init__( + self, + layer=layer, + name=name, + act_order=act_order, + len_parallel_layers=len_parallel_layers, + create_weight_orig=create_weight_orig, + p=p) + self.accumulator_bit_width = accumulator_bit_width + assert self.accumulator_bit_width is not None + + def single_layer_update(self): + # raise error in case no quant-input is here + if self.quant_input is None: + raise ValueError( + 'Expected quant input to calculate Upper Bound on L1 norm, but received None') + weight = self.layer.weight.data + dev = weight.device + dtype = weight.dtype + if isinstance(self.layer, SUPPORTED_CONV_OP): + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + weight = weight.transpose(1, 0) # This performs a view + weight = weight.flatten(1) + weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] + U = torch.zeros( + weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype) + self.float_input = self.float_input.to(dev) + self.quantized_input = self.quantized_input.to(dev) + + # get upper bound + input_bit_width = self.quant_input.bit_width + input_is_signed = self.quant_input.signed + T = get_upper_bound_on_l1_norm( + torch.tensor(self.accumulator_bit_width), input_bit_width, input_is_signed) + s = self.layer.quant_weight_scale() + s = s.view(self.groups, -1) # [Groups, OC/Groups] + + l1_norm = torch.zeros(weight.shape[:-1], device=dev) + + # We don't need full Hessian, we just need the diagonal + self.H_diag = self.quantized_input.transpose(2, 1).square().sum( + 2) # summing over Batch dimension + permutation_list = [] + for group_index in range(self.groups): + if self.act_order: + # Re-order Hessian_diagonal so that weights associated to + # higher magnitude activations are quantized first + perm = torch.argsort(self.H_diag[group_index, :], descending=True) + else: + # No permutation, permutation tensor is a ordered index + perm = torch.tensor(range(weight.shape[-1]), device=dev) + permutation_list.append(perm) + + for t in range(weight.shape[-1]): + for group_index in range(self.groups): + U[group_index] += torch.matmul( + weight[group_index, :, permutation_list[group_index][t]].unsqueeze(1), + self.float_input[group_index, :, permutation_list[group_index][t]].unsqueeze( + 0)) #[OC/Groups, 1] * [1, INSHAPE[1]] + norm = torch.linalg.norm( + self.quantized_input[group_index, :, permutation_list[group_index][t]], 2) ** 2 + if norm > 0: + q_arg = U[group_index].matmul( + self.quantized_input[group_index, :, + permutation_list[group_index][t]]) / norm + else: + q_arg = torch.zeros_like(U[group_index, :, 0]) + + weight[group_index, :, permutation_list[group_index][t]] = q_arg + q = self.get_quant_weights(t, 0, permutation_list) + + for group_index in range(self.groups): + candidate_l1 = l1_norm[group_index] + torch.abs(q[group_index]) + candidate_l1_mask = candidate_l1 > T * s[group_index] + if torch.any(candidate_l1_mask): + # set all values to 0 that are exceeding T * s + weight[group_index, :, permutation_list[group_index][t]][candidate_l1_mask] = 0 + q[group_index][candidate_l1_mask] = 0 + else: + l1_norm[group_index] = candidate_l1 + U[group_index] -= torch.matmul( + q[group_index].unsqueeze(1), + self.quantized_input[group_index, :, + permutation_list[group_index][t]].unsqueeze(0)) + + del self.float_input + del self.quantized_input diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index b10943f1b..56171ac6f 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -114,13 +114,8 @@ class GPTQ(GPxQ): """ def __init__( - self, - layer, - name, - act_order, - len_parallel_layers=1, - create_weight_orig=True, - num_blocks=100) -> None: + self, layer, name, act_order, len_parallel_layers, create_weight_orig, + num_blocks) -> None: super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig) dev = self.layer.weight.device diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index 1279950a8..ddeef1c53 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -11,6 +11,8 @@ from typing import List, Optional, Set import warnings +import torch + from brevitas.graph.calibrate import DisableEnableQuantization import brevitas.nn as qnn from brevitas.quant_tensor import QuantTensor @@ -175,13 +177,23 @@ def process_input(self, inp): if self.layer.weight_quant_requires_quant_input: # Can minimize memory allocation by not storing actual values self.quant_input = QuantTensor( - value=None, + value=torch.empty( + 1, dtype=self.layer.weight.dtype, device=self.layer.weight.device), scale=inp.scale, zero_point=inp.zero_point, bit_width=inp.bit_width, signed=inp.signed, training=inp.training) inp = inp.value + elif self.layer.is_input_quant_enabled: + self.quant_input = QuantTensor( + value=torch.empty( + 1, dtype=self.layer.weight.dtype, device=self.layer.weight.device), + scale=self.layer.quant_input_scale(), + zero_point=self.layer.quant_input_zero_point(), + bit_width=self.layer.quant_input_bit_width(), + signed=self.layer.is_quant_input_signed, + training=self.layer.training) # If input is unbatched, add batch_size = 1 if len(inp.shape) == 1: