diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py
index d206e016a..fddbfd892 100644
--- a/src/brevitas/graph/calibrate.py
+++ b/src/brevitas/graph/calibrate.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 from abc import ABC
+from copy import deepcopy
 from functools import partial
 import sys
@@ -279,6 +280,7 @@ def forward_hook_wbiol(self, module, inp, output, name):
         # Compute float reference
         self.disable_act_quantization(module, is_training=False)
         self.disable_param_quantization(module, is_training=False)
+        out_float = module.forward(*inp)  # Required to avoid infinite recursion
         self.collect_float_mean(module, out_float, name)
         self.enable_act_quantization(module, is_training=False)
diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
new file mode 100644
index 000000000..01dc11a82
--- /dev/null
+++ b/src/brevitas/graph/gpfq.py
@@ -0,0 +1,229 @@
+# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+from copy import deepcopy
+from typing import List, Optional
+
+import numpy as np
+import torch
+import unfoldNd
+
+from brevitas.graph.gpxq import GPxQ
+from brevitas.graph.gpxq import gpxq_mode
+from brevitas.graph.gpxq import StopFwdException
+from brevitas.graph.gpxq import SUPPORTED_CONV_OP
+import brevitas.nn as qnn
+
+
+class gpfq_mode(gpxq_mode):
+    """
+    Apply GPFQ algorithm.
+
+    Args:
+        model (Module): The model to quantize with GPFQ
+        inplace (bool): Whether to apply GPFQ inplace or perform a deepcopy. Default: True
+        use_quant_activations (bool): Whether to leave quantized activations enabled while performing
+            GPFQ. Default: True
+
+    Example:
+        >>> with torch.no_grad():
+        >>>     with gpfq_mode(model) as gpfq:
+        >>>         gpfq_model = gpfq.model
+        >>>         for i in tqdm(range(gpfq.num_layers)):
+        >>>             for img, t in calib_loader:
+        >>>                 img = img.cuda()
+        >>>                 gpfq_model(img)
+        >>>             gpfq.update()
+    """
+
+    def __init__(
+            self,
+            model,
+            group_of_parallel_layers: Optional[List[str]] = None,
+            inplace: bool = True,
+            create_weight_orig: bool = True,
+            use_quant_activations: bool = True,
+            p: float = 0.25,
+            return_forward_output: bool = False,
+            act_order: bool = False) -> None:
+        if not inplace:
+            model = deepcopy(model)
+        super().__init__(
+            model,
+            group_of_parallel_layers,
+            inplace,
+            create_weight_orig,
+            use_quant_activations,
+            act_order,
+            return_forward_output)
+
+        self.orig_forward = self.model.forward
+        self.model.forward = self.catch_stopfwd
+        self.class_implementation = GPFQ
+        GPFQ.p = p
+
+    def catch_stopfwd(self, *args, **kwargs):
+        # Collect quant input
+        try:
+            self.orig_forward(*args, **kwargs)
+        except StopFwdException:
+            pass
+
+        # Disable quantization
+        self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
+        self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)
+        # Collect float input
+        try:
+            self.orig_forward(*args, **kwargs)
+        except StopFwdException:
+            pass
+
+        # Re-enable quantization. 
If activation quantization is disabled, + # we also disable bias quantization + self.disable_quant_inference.enable_param_quantization(self.model, is_training=False) + if self.use_quant_activations: + self.disable_quant_inference.enable_act_quantization(self.model, is_training=False) + else: + self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False) + + if self.return_forward_output: + # If we want to return the output of the network, we need to disable all hooks + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = True + out = self.orig_forward(*args, **kwargs) + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = False + return out + + +class GPFQ(GPxQ): + """ + Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main + """ + p = 0.25 + + def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None: + + if act_order: + raise ValueError("Act_order is not supported in GPFQ") + + super().__init__(layer, name, act_order, parallel_layers, create_weight_orig) + self.float_input = None + self.quantized_input = None + self.index_computed = False + self.p = GPFQ.p + + def update_batch(self, module, input, current_layer): + if self.disable_pre_forward_hook: + return input + + # Update reference to current layer + current_layer.layer_names.add(self.name) + is_quant_disabled = module.weight_quant.disable_quant + + inp = self.process_input(input) + batch_size = inp.shape[0] + + # Preprocess the input to compute the Hessian + if isinstance(self.layer, qnn.QuantLinear): + if len(inp.shape) > 2: + inp = inp.reshape((-1, sum(inp.shape[2:]))) + # For QuantLinear layer, groups will be 1 + inp_processed = inp.unsqueeze(0) + + if isinstance(self.layer, SUPPORTED_CONV_OP): + # Pick the correct unfoldNd class + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + unfold_impl = unfoldNd.UnfoldTransposeNd + else: + unfold_impl = unfoldNd.UnfoldNd + + unfold = unfold_impl( + self.layer.kernel_size, + dilation=self.layer.dilation, + padding=self.layer.padding, + stride=self.layer.kernel_size) + + # Split input based on how many groups in convolution + inp_by_group = torch.chunk(inp, self.groups, 1) + inp_processed = [] + # Preprocess input by group + for i, inp in enumerate(inp_by_group): + + inp = unfold(inp) + + batch_size, num_blocks = inp.shape[0], inp.shape[-1] + inp = torch.transpose(inp, 1, 2) # shape (B, L, C*kernel_size[0]*kernel_size[1]) + inp = inp.reshape(-1, inp.size(-1)) # shape (B*L, C*kernel_size[0]*kernel_size[1]) + + if not self.index_computed: + self.index_computed = True + self.rand_indices = np.concatenate([ + np.random.choice( + np.arange(num_blocks * i, num_blocks * (i + 1)), + size=int( + self.p * num_blocks + 1 if self.p != 1 else self.p * num_blocks)) + for i in range(batch_size)]) # need to define self.p (probability) + + indexes = self.rand_indices + if np.max(self.rand_indices) > inp.shape[0]: + indexes = self.rand_indices < inp.shape[0] + indexes = self.rand_indices[indexes] + + inp = inp[indexes] + inp_processed.append(inp) + inp_processed = torch.stack(inp_processed) + + if is_quant_disabled: + if self.float_input is None: + self.float_input = inp_processed + else: + self.float_input = torch.cat([self.float_input, inp_processed], dim=1) + else: + if self.quantized_input is None: + self.quantized_input = inp_processed + else: + self.quantized_input = torch.cat([self.quantized_input, inp_processed], 
dim=1) + # If we are executing GPFQ with group of parallel layers, we keep track of how many forward + # we executed. Once we executed as many as the number of parallel_layers, we raise + # StopFwdException + current_layer.forward_count += 1 + if current_layer.forward_count == len(self.parallel_layers): + current_layer.forward_count = 0 + raise StopFwdException + + def single_layer_update(self): + weight = self.layer.weight.data + dev = weight.device + dtype = weight.dtype + if isinstance(self.layer, SUPPORTED_CONV_OP): + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + weight = weight.transpose(1, 0) # This performs a view + weight = weight.flatten(1) + weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] + U = torch.zeros( + weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype) + self.float_input = self.float_input.to(dev) + self.quantized_input = self.quantized_input.to(dev) + permutation_list = [torch.tensor(range(weight.shape[-1]))] + for t in range(weight.shape[-1]): + for group_index in range(self.groups): + U[group_index] += torch.matmul( + weight[group_index, :, t].unsqueeze(1), + self.float_input[group_index, :, + t].unsqueeze(0)) #[OC/Groups, 1] * [1, INSHAPE[1]] + norm = torch.linalg.norm(self.quantized_input[group_index, :, t], 2) ** 2 + if norm > 0: + q_arg = U[group_index].matmul(self.quantized_input[group_index, :, t]) / norm + else: + q_arg = torch.zeros_like(U[group_index, :, 0]) + + weight[group_index, :, t] = q_arg + q = self.get_quant_weights(t, 0, permutation_list) + for group_index in range(self.groups): + U[group_index] -= torch.matmul( + q[group_index].unsqueeze(1), + self.quantized_input[group_index, :, t].unsqueeze(0)) + + del self.float_input + del self.quantized_input diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index 8f8ffb6ae..b224e0a37 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -2,11 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause from copy import deepcopy -from dataclasses import dataclass -from dataclasses import field -from functools import partial import math -from operator import attrgetter from typing import List, Optional, Set import warnings @@ -16,28 +12,16 @@ from torch.linalg import LinAlgError except: LinAlgError = RuntimeError - import unfoldNd -from brevitas.graph.calibrate import DisableEnableQuantization +from brevitas.graph.gpxq import GPxQ +from brevitas.graph.gpxq import gpxq_mode +from brevitas.graph.gpxq import StopFwdException +from brevitas.graph.gpxq import SUPPORTED_CONV_OP import brevitas.nn as qnn -from brevitas.quant_tensor import QuantTensor - -SUPPORTED_CONV_OP = ( - qnn.QuantConv2d, qnn.QuantConv1d, qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d) - - -class StopFwdException(Exception): - pass -@dataclass -class LayerHandler: - layer_names: Set = field(default_factory=set) - forward_count: int = 0 - - -class gptq_mode: +class gptq_mode(gpxq_mode): """ Apply GPTQ algorithm https://arxiv.org/abs/2210.17323. 
@@ -63,99 +47,28 @@ def __init__( model, group_of_parallel_layers: Optional[List[str]] = None, inplace: bool = True, + create_weight_orig: bool = True, use_quant_activations: bool = True, num_blocks: int = 100, - act_order: bool = False, - return_forward_output: bool = False) -> None: + return_forward_output: bool = False, + act_order: bool = False) -> None: if not inplace: model = deepcopy(model) - self.model = model - self.use_quant_activations = use_quant_activations - self.hook_dict = dict() - self.gptq_layers = dict() - # reference for each layer to update - self.current_layer = LayerHandler() - # How many layer to optimize - self.num_layers = 0 - # Quantize following magnitude of activation - self.act_order = act_order - # How many subblock to use during GPTQ for each layer - self.num_blocks = num_blocks + super().__init__( + model, + group_of_parallel_layers, + inplace, + create_weight_orig, + use_quant_activations, + act_order, + return_forward_output) - self.disable_quant_inference = DisableEnableQuantization() self.orig_forward = self.model.forward self.model.forward = self.catch_stopfwd - self.group_of_parallel_layers = group_of_parallel_layers - self.return_forward_output = return_forward_output - - def _is_module_supported(self, module): - if isinstance(module, SUPPORTED_CONV_OP): - return True - elif isinstance(module, qnn.QuantLinear): - return True - else: - return False - - def __enter__(self): - # The user can specify on which layers to apply gptq in parallel. - # All the others will be executed sequentially - dict_of_layers = { - name: [(name, module)] for name, - module in self.model.named_modules() if self._is_module_supported(module)} - if self.group_of_parallel_layers is not None: - for parallel_layers in self.group_of_parallel_layers: - for name in parallel_layers: - if name not in dict_of_layers: - raise ValueError( - "The layer {} is not present in the model or it is not supported for GPTQ" - .format(name)) - del dict_of_layers[name] - names = '_'.join(parallel_layers) - dict_of_layers[names] = [ - (name, attrgetter(name)(self.model)) for name in parallel_layers] - - # Print warning if hooks are attached to any module, since the normal forward flow of the - # network is highly disrupted during GPTQ - for _, parallel_layers in dict_of_layers.items(): - for name, module in parallel_layers: - if len(module._forward_hooks) > 0 or len(module._forward_pre_hooks): - warnings.warn( - f'Hooks detected during setup for GPTQ. 
' - f'Behaviour might deviate from what expected.') - - # Attach hooks for GPTQ - if self._is_module_supported(module): - gptq = GPTQ( - module, - name, - num_blocks=self.num_blocks, - act_order=self.act_order, - parallel_layers=parallel_layers) - hook_fn = partial(gptq.update_batch, current_layer=self.current_layer) - self.hook_dict[name] = module.register_forward_pre_hook(hook_fn) - self.gptq_layers[name] = gptq - if not self.use_quant_activations: - self.disable_quant_inference.disable_act_quantization( - self.model, is_training=self.model.training) - self.disable_quant_inference.disable_bias_quantization( - self.model, is_training=self.model.training) - - self.num_layers = len(dict_of_layers) - return self - - def __exit__(self, type, value, traceback): - self.model.forward = self.orig_forward - if not self.use_quant_activations: - self.disable_quant_inference.enable_act_quantization( - self.model, is_training=self.model.training) - self.disable_quant_inference.enable_bias_quantization( - self.model, is_training=self.model.training) - - def update(self): - for name in self.current_layer.layer_names: - self.gptq_layers[name].single_layer_update() - self.hook_dict[name].remove() - self.current_layer.layer_names.clear() + # How many subblock to use during GPTQ for each layer + self.num_blocks = num_blocks + self.class_implementation = GPTQ + GPTQ.num_blocks = num_blocks def catch_stopfwd(self, *args, **kwargs): try: @@ -165,15 +78,15 @@ def catch_stopfwd(self, *args, **kwargs): finally: if self.return_forward_output: # If we want to return the output of the network, we need to disable all hooks - for name, gptq_class in self.gptq_layers.items(): - gptq_class.disable_pre_forward_hook = True + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = True out = self.orig_forward(*args, **kwargs) - for name, gptq_class in self.gptq_layers.items(): - gptq_class.disable_pre_forward_hook = False + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = False return out -class GPTQ(): +class GPTQ(GPxQ): """ Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE: @@ -191,74 +104,29 @@ class GPTQ(): See the License for the specific language governing permissions and limitations under the License. """ + num_blocks = 100 - def __init__(self, layer, name, num_blocks, act_order, parallel_layers=1) -> None: - self.layer = layer - self.name = name - self.num_blocks = num_blocks - self.act_order = act_order + def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None: + super().__init__(layer, name, act_order, parallel_layers, create_weight_orig) - weight = layer.weight.data - dev = weight.device - - # By default, use groups = 1 - self.groups = 1 - if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): - weight = weight.transpose(1, 0) # This performs a view - weight = weight.flatten(1) - self.groups = self.layer.groups - - # Number of rows is equal to the output channels (OC) - self.rows = weight.shape[0] - # Number of columns is equal to the input channels (IC) - self.columns = weight.shape[1] + dev = self.layer.weight.device # Define how many columns to update in each mini-block - self.blocksize = math.ceil(self.columns / self.num_blocks) + self.blocksize = math.ceil(self.columns / GPTQ.num_blocks) # Initialize Hessian matrix and counter. 
We need it in float32 to compute the inverse self.H = torch.zeros((self.groups, self.columns, self.columns), device=dev, dtype=torch.float32) self.nsamples = 0 - self.parallel_layers = parallel_layers - - self.disable_pre_forward_hook = False - # Some layers require knowledge from quant inputs to compute quant weights - self.quant_input = None def update_batch(self, module, input, current_layer): if self.disable_pre_forward_hook: return input + # Update reference to current layer current_layer.layer_names.add(self.name) - - # Input is a tuple, so we take first element - inp = input[0] - # If using Quant Activations, inp could be QuantTensor - if isinstance(inp, QuantTensor): - if self.layer.weight_quant_requires_quant_input: - # Can minimize memory allocation by not storing actual values - self.quant_input = QuantTensor( - value=None, - scale=inp.scale, - zero_point=inp.zero_point, - bit_width=inp.bit_width, - signed=inp.signed, - training=inp.training) - inp = inp.value - - # If input is unbatched, add batch_size = 1 - if len(inp.shape) == 1: - warnings.warn("Found unbatched input, adding batch dimension equal to 1") - inp = inp.unsqueeze(0) - - # Define batch size before re-organizing the input - if hasattr(inp, 'names') and 'N' in inp.names: - batch_dim = inp.names.index('N') - inp.rename_(None) - inp = inp.transpose(0, batch_dim) + inp = self.process_input(input) batch_size = inp.shape[0] # Preprocess the input to compute the Hessian @@ -390,53 +258,3 @@ def single_layer_update(self, percdamp=.01): perm = permutation_list[group_index] weight[group_index, :, perm[i2:]] -= ( error_block[group_index].matmul(h_inv[group_index, i1:i2, i2:])).to(dtype) - - def get_quant_weights(self, i, i1, permutation_list): - # We need to recompute quant weights at runtime since our float weights are being updated - # Add offset in case of blockwise computation (e.g., GPTQ) - i = i1 + i - # For QuantLinear and for some QuantConvolutional layers, we exploit the possibility - # of quantizing only a subset of the entire matrix speeding up the computation of GPTQ - if isinstance(self.layer, qnn.QuantLinear): - index = permutation_list[0][i] - subtensor_slice_list = [None, (index, index + 1)] - q = self.layer.quant_weight( - subtensor_slice_list=subtensor_slice_list, - quant_input=self.quant_input).value.unsqueeze(0) # [1, OC, 1] - elif isinstance(self.layer, SUPPORTED_CONV_OP): - # For depthwise and ConvTranspose we fall back to quantizing the entire martix. - # For all other cases, we create a mask that represent the slicing we will perform on the weight matrix - # and we quantize only the selected dimensions. 
- if self.groups > 1 or (self.groups == 1 and isinstance( - self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d))): - - quant_weight = self.layer.quant_weight(quant_input=self.quant_input) - quant_weight = quant_weight.value - - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): - quant_weight = quant_weight.transpose(1, 0) # This performs a view - quant_weight = quant_weight.flatten(1) - quant_weight = quant_weight.view(self.groups, -1, quant_weight.shape[-1]) - - if self.act_order: - for ii, perm in enumerate(permutation_list): - quant_weight[ii, :, :] = quant_weight[ii, :, perm] - - q = quant_weight[:, :, i:i + 1] # [groups, OC/groups, 1] - else: - index = permutation_list[0][i] - shapes = self.layer.weight.shape[1:] - index_2d_to_nd = [] - residual_index = index.item() - for shape in shapes[::-1]: - index_2d_to_nd.append((residual_index % shape, residual_index % shape + 1)) - residual_index = residual_index // shape - index_2d_to_nd = index_2d_to_nd[::-1] - index_2d_to_nd.insert(0, None) - q = self.layer.quant_weight( - subtensor_slice_list=index_2d_to_nd, - quant_input=self.quant_input).value.flatten(1) # [OC, 1] - q = q.unsqueeze(0) # [1, OC, 1] - # We need to remove the last dim - q = q.squeeze(2) # [groups, OC/groups] or [1, OC] - return q diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py new file mode 100644 index 000000000..b13c46683 --- /dev/null +++ b/src/brevitas/graph/gpxq.py @@ -0,0 +1,252 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from abc import ABC +from abc import abstractmethod +from copy import deepcopy +from dataclasses import dataclass +from dataclasses import field +from functools import partial +from operator import attrgetter +from typing import List, Optional, Set +import warnings + +from brevitas.graph.calibrate import DisableEnableQuantization +import brevitas.nn as qnn +from brevitas.quant_tensor import QuantTensor + +SUPPORTED_CONV_OP = ( + qnn.QuantConv2d, qnn.QuantConv1d, qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d) + + +class StopFwdException(Exception): + pass + + +@dataclass +class LayerHandler: + layer_names: Set = field(default_factory=set) + forward_count: int = 0 + + +class gpxq_mode(ABC): + + def __init__( + self, + model, + group_of_parallel_layers: Optional[List[str]] = None, + inplace: bool = True, + create_weight_orig: bool = True, + use_quant_activations: bool = True, + act_order: bool = False, + return_forward_output: bool = False) -> None: + + if not inplace: + model = deepcopy(model) + self.model = model + self.create_weight_orig = create_weight_orig + self.use_quant_activations = use_quant_activations + self.hook_dict = dict() + self.gpxq_layers = dict() + # reference for each layer to update + self.current_layer = LayerHandler() + # How many layer to optimize + self.num_layers = 0 + # Quantize following magnitude of activation + self.act_order = act_order + # How many subblock to use during GPTQ for each layer + + self.disable_quant_inference = DisableEnableQuantization() + + self.group_of_parallel_layers = group_of_parallel_layers + self.return_forward_output = return_forward_output + + def _is_module_supported(self, module): + if isinstance(module, SUPPORTED_CONV_OP): + return True + elif isinstance(module, qnn.QuantLinear): + return True + else: + return False + + def __enter__(self): + # The user can specify on which layers to apply gptq in parallel. 
+ # All the others will be executed sequentially + dict_of_layers = { + name: [(name, module)] for name, + module in self.model.named_modules() if self._is_module_supported(module)} + if self.group_of_parallel_layers is not None: + for parallel_layers in self.group_of_parallel_layers: + for name in parallel_layers: + if name not in dict_of_layers: + raise ValueError( + "The layer {} is not present in the model or it is not supported for GPTQ" + .format(name)) + del dict_of_layers[name] + names = '_'.join(parallel_layers) + dict_of_layers[names] = [ + (name, attrgetter(name)(self.model)) for name in parallel_layers] + + # Print warning if hooks are attached to any module, since the normal forward flow of the + # network is highly disrupted during GPxQ + for _, parallel_layers in dict_of_layers.items(): + for name, module in parallel_layers: + if len(module._forward_hooks) > 0 or len(module._forward_pre_hooks): + warnings.warn( + f'Hooks detected during setup for GPxQ. ' + f'Behaviour might deviate from what expected.') + + # Attach hooks for GPTQ + if self._is_module_supported(module): + gpxq = self.class_implementation( + module, + name, + act_order=self.act_order, + parallel_layers=parallel_layers, + create_weight_orig=self.create_weight_orig) + hook_fn = partial(gpxq.update_batch, current_layer=self.current_layer) + self.hook_dict[name] = module.register_forward_pre_hook(hook_fn) + self.gpxq_layers[name] = gpxq + if not self.use_quant_activations: + self.disable_quant_inference.disable_act_quantization( + self.model, is_training=self.model.training) + self.disable_quant_inference.disable_bias_quantization( + self.model, is_training=self.model.training) + + self.num_layers = len(dict_of_layers) + return self + + def __exit__(self, type, value, traceback): + self.model.forward = self.orig_forward + if not self.use_quant_activations: + self.disable_quant_inference.enable_act_quantization( + self.model, is_training=self.model.training) + self.disable_quant_inference.enable_bias_quantization( + self.model, is_training=self.model.training) + + def update(self): + for name in self.current_layer.layer_names: + self.gpxq_layers[name].single_layer_update() + self.hook_dict[name].remove() + self.current_layer.layer_names.clear() + + @abstractmethod + def catch_stopfwd(self, *args, **kwargs): + pass + + +class GPxQ(ABC): + + def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None: + self.layer = layer + self.name = name + self.act_order = act_order + + weight = layer.weight.data + + if create_weight_orig and not hasattr(self.layer, 'weight_orig'): + self.layer.register_buffer('weight_orig', layer.weight.detach().clone()) + + # By default, use groups = 1 + self.groups = 1 + if isinstance(self.layer, SUPPORTED_CONV_OP): + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + weight = weight.transpose(1, 0) # This performs a view + weight = weight.flatten(1) + self.groups = self.layer.groups + + # Number of rows is equal to the output channels (OC) + self.rows = weight.shape[0] + # Number of columns is equal to the input channels (IC) + self.columns = weight.shape[1] + self.parallel_layers = parallel_layers + + self.disable_pre_forward_hook = False + # Some layers require knowledge from quant inputs to compute quant weights + self.quant_input = None + + def process_input(self, inp): + # Input is a tuple, so we take first element + inp = inp[0] + # If using Quant Activations, inp could be QuantTensor + if isinstance(inp, QuantTensor): + 
if self.layer.weight_quant_requires_quant_input: + # Can minimize memory allocation by not storing actual values + self.quant_input = QuantTensor( + value=None, + scale=inp.scale, + zero_point=inp.zero_point, + bit_width=inp.bit_width, + signed=inp.signed, + training=inp.training) + inp = inp.value + + # If input is unbatched, add batch_size = 1 + if len(inp.shape) == 1: + warnings.warn("Found unbatched input, adding batch dimension equal to 1") + inp = inp.unsqueeze(0) + + # Define batch size before re-organizing the input + if hasattr(inp, 'names') and 'N' in inp.names: + batch_dim = inp.names.index('N') + inp.rename_(None) + inp = inp.transpose(0, batch_dim) + return inp + + @abstractmethod + def update_batch(self): + pass + + @abstractmethod + def single_layer_update(self): + pass + + def get_quant_weights(self, i, i1, permutation_list): + # We need to recompute quant weights at runtime since our float weights are being updated + # Add offset in case of blockwise computation + i = i1 + i + # For QuantLinear and for some QuantConvolutional layers, we exploit the possibility + # of quantizing only a subset of the entire matrix speeding up the computation of GPxQ + if isinstance(self.layer, qnn.QuantLinear): + index = permutation_list[0][i] + subtensor_slice_list = [None, (index, index + 1)] + q = self.layer.quant_weight( + subtensor_slice_list=subtensor_slice_list, + quant_input=self.quant_input).value.unsqueeze(0) # [1, OC, 1] + elif isinstance(self.layer, SUPPORTED_CONV_OP): + # For depthwise and ConvTranspose we fall back to quantizing the entire martix. + # For all other cases, we create a mask that represent the slicing we will perform on the weight matrix + # and we quantize only the selected dimensions. + if self.groups > 1 or (self.groups == 1 and isinstance( + self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d))): + + quant_weight = self.layer.quant_weight(quant_input=self.quant_input) + quant_weight = quant_weight.value + + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + quant_weight = quant_weight.transpose(1, 0) # This performs a view + quant_weight = quant_weight.flatten(1) + quant_weight = quant_weight.view(self.groups, -1, quant_weight.shape[-1]) + + if self.act_order: + for ii, perm in enumerate(permutation_list): + quant_weight[ii, :, :] = quant_weight[ii, :, perm] + + q = quant_weight[:, :, i:i + 1] # [groups, OC/groups, 1] + else: + index = permutation_list[0][i] + shapes = self.layer.weight.shape[1:] + index_2d_to_nd = [] + residual_index = index.item() + for shape in shapes[::-1]: + index_2d_to_nd.append((residual_index % shape, residual_index % shape + 1)) + residual_index = residual_index // shape + index_2d_to_nd = index_2d_to_nd[::-1] + index_2d_to_nd.insert(0, None) + q = self.layer.quant_weight( + subtensor_slice_list=index_2d_to_nd, + quant_input=self.quant_input).value.flatten(1) # [OC, 1] + q = q.unsqueeze(0) # [1, OC, 1] + # We need to remove the last dim + q = q.squeeze(2) # [groups, OC/groups] or [1, OC] + return q diff --git a/src/brevitas/nn/mixin/parameter.py b/src/brevitas/nn/mixin/parameter.py index f65621c3f..095c981f1 100644 --- a/src/brevitas/nn/mixin/parameter.py +++ b/src/brevitas/nn/mixin/parameter.py @@ -61,6 +61,9 @@ def quant_weight( self, quant_input: Optional[QuantTensor] = None, subtensor_slice_list: List[Optional[Tuple[int, int]]] = None): + weights_to_quantize = self.weight + if not self.weight_quant.is_quant_enabled and hasattr(self, 'weight_orig'): + weights_to_quantize = 
self.weight_orig if subtensor_slice_list is not None: # prepare the quantizer for a subtensor input, if any modifications are required # we set a list of tuples rather than a list of slices so that it's jit friendly @@ -95,9 +98,9 @@ def quant_weight( input_bit_width = None input_is_signed = None out = self.weight_quant( - self.weight[weight_slice_tuple], input_bit_width, input_is_signed) + weights_to_quantize[weight_slice_tuple], input_bit_width, input_is_signed) else: - out = self.weight_quant(self.weight[weight_slice_tuple]) + out = self.weight_quant(weights_to_quantize[weight_slice_tuple]) if subtensor_slice_list is not None: # Restore the quantizer behaviour to full tensor quantization # The modules to slice should have been cached already at this point diff --git a/src/brevitas_examples/imagenet_classification/ptq/README.md b/src/brevitas_examples/imagenet_classification/ptq/README.md index e0e7c7455..29386659b 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/README.md +++ b/src/brevitas_examples/imagenet_classification/ptq/README.md @@ -36,6 +36,7 @@ Furthermore, Brevitas additional PTQ techniques can be enabled: - If Graph equalization is enabled, the _merge\_bias_ technique can be enabled.[2 ] [3 ]. - GPTQ [4 ]. - Learned Round [5 ]. +- GPFQ [6 ]. Internally, when defining a quantized model programmatically, Brevitas leverages `torch.fx` and its `symbolic_trace` functionality, meaning that an input model is required to pass symbolic tracing for it to work. @@ -85,7 +86,8 @@ usage: ptq_evaluate.py [-h] --calibration-dir CALIBRATION_DIR --validation-dir [--bias-corr | --no-bias-corr] [--graph-eq-merge-bias | --no-graph-eq-merge-bias] [--weight-narrow-range | --no-weight-narrow-range] - [--gptq | --no-gptq] + [--gpfq-p GPFQ_P] [--gptq | --no-gptq] + [--gpfq | --no-gpfq] [--gptq-act-order | --no-gptq-act-order] [--learned-round | --no-learned-round] [--calibrate-bn | --no-calibrate-bn] @@ -171,8 +173,11 @@ optional arguments: Enable Narrow range for weight quantization (default: enabled) --no-weight-narrow-range Disable Narrow range for weight quantization (default: enabled) + --gpfq-p GPFQ_P P parameter for GPFQ (default: 0.25) --gptq Enable GPTQ (default: enabled) --no-gptq Disable GPTQ (default: enabled) + --gpfq Enable GPFQ (default: disabled) + --no-gpfq Disable GPFQ (default: disabled) --gptq-act-order Enable GPTQ Act order heuristic (default: disabled) --no-gptq-act-order Disable GPTQ Act order heuristic (default: disabled) --learned-round Enable Learned round (default: disabled) @@ -208,3 +213,4 @@ and a `RESULTS_IMGCLSMOB.csv` with the results on manually quantized models star [3 ]: https://github.com/openppl-public/ppq/blob/master/ppq/quantization/algorithm/equalization.py [4 ]: https://arxiv.org/abs/2210.17323 [5 ]: https://arxiv.org/abs/2004.10568 +[6 ]: https://arxiv.org/abs/2201.11113 diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py index 9acafbf58..9a97b4794 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py +++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py @@ -24,6 +24,7 @@ from brevitas.graph.target.flexml import preprocess_for_flexml_quantize from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization from brevitas_examples.imagenet_classification.ptq.ptq_common import 
apply_bias_correction +from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gpfq from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gptq from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_learned_round_learning from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate @@ -52,6 +53,7 @@ 'bias_bit_width': [32, 16], # Bias Bit-Width for Po2 scale 'weight_quant_granularity': ['per_tensor', 'per_channel'], # Scaling Per Output Channel 'act_quant_type': ['asym', 'sym'], # Act Quant Type + 'weight_param_method': ['stats', 'mse'], # Weight Quant Type 'act_param_method': ['stats', 'mse'], # Act Param Method 'bias_corr': [True], # Bias Correction 'graph_eq_iterations': [0, 20], # Graph Equalization @@ -60,6 +62,8 @@ 'learned_round': [False, True], # Enable/Disable Learned Round 'gptq': [False, True], # Enable/Disable GPTQ 'gptq_act_order': [False, True], # Use act_order euristics for GPTQ + 'gpfq': [False, True], # Enable/Disable GPFQ + 'gpfq_p': [0.25, 0.75], # GPFQ P 'act_quant_percentile': [99.9, 99.99, 99.999], # Activation Quantization Percentile } @@ -71,13 +75,16 @@ 'bias_bit_width': [32], # Bias Bit-Width for Po2 scale 'weight_quant_granularity': ['per_channel'], # Scaling Per Output Channel 'act_quant_type': ['sym'], # Act Quant Type - 'act_param_method': ['stats'], # Act Param Method + 'act_param_method': ['mse'], # Act Param Method + 'weight_param_method': ['stats'], # Weight Quant Type 'bias_corr': [True], # Bias Correction 'graph_eq_iterations': [20], # Graph Equalization 'graph_eq_merge_bias': [True], # Merge bias for Graph Equalization 'act_equalization': [None], # Perform Activation Equalization (Smoothquant) 'learned_round': [False], # Enable/Disable Learned Round 'gptq': [True], # Enable/Disable GPTQ + 'gpfq': [False], # Enable/Disable GPFQ + 'gpfq_p': [0.25], # GPFQ P 'gptq_act_order': [False], # Use act_order euristics for GPTQ 'act_quant_percentile': [99.999], # Activation Quantization Percentile } @@ -114,35 +121,36 @@ def main(): args.gpu = get_gpu_index(args.idx) print("Iter {}, GPU {}".format(args.idx, args.gpu)) - - options_names = [k.replace('_', ' ').capitalize() for k in OPTIONS.keys()] - torchvision_df = pd.DataFrame( - columns=options_names + [ - 'Top 1% floating point accuracy', - 'Top 1% quant accuracy', - 'Floating point accuracy - quant accuracy', - 'Quant accuracy / floating point accuracy', - 'Calibration size', - 'Calibration batch size', - 'Torch version', - 'Brevitas version']) try: - ptq_torchvision_models(torchvision_df, args) + ptq_torchvision_models(args) except Exception as E: print("Exception at index {}: {}".format(args.idx, E)) -def ptq_torchvision_models(df, args): +def ptq_torchvision_models(args): # Generate all possible combinations, including invalid ones # Split stats and mse due to the act_quant_percentile value - percentile_options = OPTIONS.copy() - percentile_options['act_param_method'] = ['stats'] - mse_options = OPTIONS.copy() - mse_options['act_param_method'] = ['mse'] - mse_options['act_quant_percentile'] = [None] + + if 'stats' in OPTIONS['act_param_method']: + percentile_options = OPTIONS.copy() + percentile_options['act_param_method'] = ['stats'] + else: + percentile_options = None + + if 'mse' in OPTIONS['act_param_method']: + mse_options = OPTIONS.copy() + mse_options['act_param_method'] = ['mse'] + mse_options['act_quant_percentile'] = [None] + else: + mse_options = None + + # Combine MSE and Percentile combinations, if they are defined + 
combinations = [] + if mse_options is not None: + combinations += list(product(*mse_options.values())) + if percentile_options is not None: + combinations += list(product(*percentile_options.values())) # Combine the two sets of combinations - combinations = list(product(*percentile_options.values())) + list( - product(*mse_options.values())) # Generate Namespace for each configuration configs = [ SimpleNamespace(**{k: v @@ -152,10 +160,12 @@ def ptq_torchvision_models(df, args): configs = list(map(validate_config, configs)) # Drop invalid configurations configs = list(config for config in configs if config.is_valid) + if args.idx > len(configs): return config_namespace = configs[args.idx] + print(config_namespace) fp_accuracy = TORCHVISION_TOP1_MAP[config_namespace.model_name] # Get model-specific configurations about input shapes and normalization @@ -212,6 +222,8 @@ def ptq_torchvision_models(df, args): backend=config_namespace.target_backend, act_bit_width=config_namespace.act_bit_width, weight_bit_width=config_namespace.weight_bit_width, + weight_param_method=config_namespace.weight_param_method, + act_param_method=config_namespace.act_param_method, bias_bit_width=config_namespace.bias_bit_width, weight_quant_granularity=config_namespace.weight_quant_granularity, act_quant_percentile=config_namespace.act_quant_percentile, @@ -228,6 +240,10 @@ def ptq_torchvision_models(df, args): print("Starting calibration") calibrate(calib_loader, quant_model) + if config_namespace.gpfq: + print("Performing GPFQ:") + apply_gpfq(calib_loader, quant_model, p=config_namespace.gpfq_p) + if config_namespace.gptq: print("Performing gptq") apply_gptq(calib_loader, quant_model, config_namespace.gptq_act_order) @@ -292,6 +308,10 @@ def validate_config(config_namespace): if not config_namespace.gptq and config_namespace.gptq_act_order: is_valid = False + # If GPFQ is disabled, we execute only one configuration for p==0.25 + if not config_namespace.gpfq and config_namespace.gpfq_p == 0.75: + is_valid = False + if config_namespace.act_equalization == 'layerwise' and config_namespace.target_backend == 'fx': is_valid = False if config_namespace.act_bit_width < config_namespace.weight_bit_width: diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 7982059f8..f2ae5092c 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -13,6 +13,7 @@ from brevitas.graph.calibrate import calibration_mode from brevitas.graph.calibrate import norm_correction_mode from brevitas.graph.equalize import activation_equalization_mode +from brevitas.graph.gpfq import gpfq_mode from brevitas.graph.gptq import gptq_mode from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.quantize import quantize @@ -203,7 +204,7 @@ def kwargs_prefix(prefix, weight_kwargs): weight_quant = weight_quant.let(zero_point_impl=ParameterFromStatsFromParameterZeroPoint) if act_quant is not None: act_quant = act_quant.let(**{'high_percentile_q': act_quant_percentile, 'dtype': dtype}) - if act_quant_type == 'asym': + if act_quant_type == 'asym' and act_quant_percentile is not None: act_quant = act_quant.let(**{'low_percentile_q': 100 - act_quant_percentile}) if sym_act_quant is not None: sym_act_quant = sym_act_quant.let( @@ -213,7 +214,7 @@ def kwargs_prefix(prefix, weight_kwargs): per_tensor_act_quant = per_tensor_act_quant.let( **{ 'high_percentile_q': 
act_quant_percentile, 'dtype': dtype}) - if act_quant_type == 'asym': + if act_quant_type == 'asym' and act_quant_percentile is not None: per_tensor_act_quant = per_tensor_act_quant.let( **{'low_percentile_q': 100 - act_quant_percentile}) @@ -360,6 +361,21 @@ def apply_gptq(calib_loader, model, act_order=False): gptq.update() +def apply_gpfq(calib_loader, model, p=0.25): + model.eval() + dtype = next(model.parameters()).dtype + device = next(model.parameters()).device + with torch.no_grad(): + with gpfq_mode(model, p=p, use_quant_activations=True) as gpfq: + gpfq_model = gpfq.model + for i in tqdm(range(gpfq.num_layers)): + for i, (images, target) in enumerate(calib_loader): + images = images.to(device) + images = images.to(dtype) + gpfq_model(images) + gpfq.update() + + def apply_learned_round_learning( model, dataloader, optimizer_class=torch.optim.Adam, iters=1000, optimizer_lr=1e-1): layers = [] diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index fdf2b966c..a560cd4c3 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -22,6 +22,7 @@ from brevitas.graph.target.flexml import preprocess_for_flexml_quantize from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_bias_correction +from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gpfq from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gptq from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_learned_round_learning from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate @@ -165,7 +166,10 @@ 'weight-narrow-range', default=True, help='Narrow range for weight quantization (default: enabled)') +parser.add_argument( + '--gpfq-p', default=0.25, type=float, help='P parameter for GPFQ (default: 0.25)') add_bool_arg(parser, 'gptq', default=True, help='GPTQ (default: enabled)') +add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)') add_bool_arg( parser, 'gptq-act-order', default=False, help='GPTQ Act order heuristic (default: disabled)') add_bool_arg(parser, 'learned-round', default=False, help='Learned round (default: disabled)') @@ -191,6 +195,7 @@ def main(): f"a{args.act_bit_width}" f"w{args.weight_bit_width}_" f"{'gptq_' if args.gptq else ''}" + f"{'gpfq_' if args.gpfq else ''}" f"{'gptq_act_order_' if args.gptq_act_order else ''}" f"{'learned_round_' if args.learned_round else ''}" f"{'weight_narrow_range_' if args.weight_narrow_range else ''}" @@ -211,6 +216,8 @@ def main(): f"Activation bit width: {args.act_bit_width} - " f"Weight bit width: {args.weight_bit_width} - " f"GPTQ: {args.gptq} - " + f"GPFQ: {args.gpfq} - " + f"GPFQ P: {args.gpfq_p} - " f"GPTQ Act Order: {args.gptq_act_order} - " f"Learned Round: {args.learned_round} - " f"Weight narrow range: {args.weight_narrow_range} - " @@ -299,9 +306,13 @@ def main(): print("Starting activation calibration:") calibrate(calib_loader, quant_model) + if args.gpfq: + print("Performing GPFQ:") + apply_gpfq(calib_loader, quant_model, p=args.gpfq_p) + if args.gptq: print("Performing GPTQ:") - apply_gptq(calib_loader, quant_model, args.gptq_act_order) + apply_gptq(calib_loader, quant_model, act_order=args.gptq_act_order) if args.learned_round: print("Applying Learned Round:")
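
For readers who want to see the new GPFQ entry points end to end, here is a minimal usage sketch. It is illustrative only and not part of the patch: `quant_model` and `calib_loader` stand in for an already-quantized model and a calibration DataLoader, and the snippet simply mirrors the gpfq_mode docstring and the new apply_gpfq helper above.

>>> import torch
>>> from brevitas.graph.gpfq import gpfq_mode
>>> with torch.no_grad():
>>>     with gpfq_mode(quant_model, p=0.25, use_quant_activations=True) as gpfq:
>>>         gpfq_model = gpfq.model
>>>         # One outer iteration per layer: each pass replays the calibration set so
>>>         # GPFQ.update_batch() can collect that layer's float and quantized inputs.
>>>         for _ in range(gpfq.num_layers):
>>>             for images, _ in calib_loader:
>>>                 gpfq_model(images.cuda())
>>>             gpfq.update()

Note that gpfq_mode.catch_stopfwd runs each calibration batch through the model twice, once to collect quantized inputs and once with quantization disabled to collect float inputs, so each outer iteration costs roughly two (truncated) forward passes per batch where gptq_mode needs only one.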