diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py
index d206e016a..fddbfd892 100644
--- a/src/brevitas/graph/calibrate.py
+++ b/src/brevitas/graph/calibrate.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause
from abc import ABC
+from copy import deepcopy
from functools import partial
import sys
@@ -279,6 +280,7 @@ def forward_hook_wbiol(self, module, inp, output, name):
# Compute float reference
self.disable_act_quantization(module, is_training=False)
self.disable_param_quantization(module, is_training=False)
+
out_float = module.forward(*inp) # Required to avoid infinite recursion
self.collect_float_mean(module, out_float, name)
self.enable_act_quantization(module, is_training=False)
diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
new file mode 100644
index 000000000..01dc11a82
--- /dev/null
+++ b/src/brevitas/graph/gpfq.py
@@ -0,0 +1,229 @@
+# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+from copy import deepcopy
+from typing import List, Optional
+
+import numpy as np
+import torch
+import unfoldNd
+
+from brevitas.graph.gpxq import GPxQ
+from brevitas.graph.gpxq import gpxq_mode
+from brevitas.graph.gpxq import StopFwdException
+from brevitas.graph.gpxq import SUPPORTED_CONV_OP
+import brevitas.nn as qnn
+
+
+class gpfq_mode(gpxq_mode):
+ """
+    Apply GPFQ algorithm (https://arxiv.org/abs/2201.11113).
+
+ Args:
+ model (Module): The model to quantize with GPFQ
+        inplace (bool): Whether to apply GPFQ inplace or perform a deepcopy. Default: True
+        use_quant_activations (bool): Whether to leave quantized activations enabled while
+            performing GPFQ. Default: True
+
+ Example:
+ >>> with torch.no_grad():
+ >>> with gpfq_mode(model) as gpfq:
+ >>> gpfq_model = gpfq.model
+ >>> for i in tqdm(range(gpfq.num_layers)):
+ >>> for img, t in calib_loader:
+ >>> img = img.cuda()
+ >>> gpfq_model(img)
+ >>> gpfq.update()
+ """
+
+ def __init__(
+ self,
+ model,
+ group_of_parallel_layers: Optional[List[str]] = None,
+ inplace: bool = True,
+ create_weight_orig: bool = True,
+ use_quant_activations: bool = True,
+ p: int = 0.25,
+ return_forward_output: bool = False,
+ act_order: bool = False) -> None:
+ if not inplace:
+ model = deepcopy(model)
+ super().__init__(
+ model,
+ group_of_parallel_layers,
+ inplace,
+ create_weight_orig,
+ use_quant_activations,
+ act_order,
+ return_forward_output)
+
+ self.orig_forward = self.model.forward
+ self.model.forward = self.catch_stopfwd
+ self.class_implementation = GPFQ
+ GPFQ.p = p
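+        # p is stored as a class attribute because layers are instantiated generically
+        # through self.class_implementation inside gpxq_mode.__enter__, which does not
+        # forward algorithm-specific keyword arguments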
+
+ def catch_stopfwd(self, *args, **kwargs):
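+        # GPFQ needs both the quantized and the float inputs of each layer, so the forward
+        # pass is run twice: once with quantization enabled and once with it disabled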
+ # Collect quant input
+ try:
+ self.orig_forward(*args, **kwargs)
+ except StopFwdException:
+ pass
+
+ # Disable quantization
+ self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
+ self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)
+ # Collect float input
+ try:
+ self.orig_forward(*args, **kwargs)
+ except StopFwdException:
+ pass
+
+ # Re-enable quantization. If activation quantization is disabled,
+ # we also disable bias quantization
+ self.disable_quant_inference.enable_param_quantization(self.model, is_training=False)
+ if self.use_quant_activations:
+ self.disable_quant_inference.enable_act_quantization(self.model, is_training=False)
+ else:
+ self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False)
+
+ if self.return_forward_output:
+ # If we want to return the output of the network, we need to disable all hooks
+ for name, gpxq_class in self.gpxq_layers.items():
+ gpxq_class.disable_pre_forward_hook = True
+ out = self.orig_forward(*args, **kwargs)
+ for name, gpxq_class in self.gpxq_layers.items():
+ gpxq_class.disable_pre_forward_hook = False
+ return out
+
+
+class GPFQ(GPxQ):
+ """
+ Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main
+ """
+ p = 0.25
+
+ def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None:
+
+ if act_order:
+ raise ValueError("Act_order is not supported in GPFQ")
+
+ super().__init__(layer, name, act_order, parallel_layers, create_weight_orig)
+ self.float_input = None
+ self.quantized_input = None
+ self.index_computed = False
+ self.p = GPFQ.p
+
+ def update_batch(self, module, input, current_layer):
+ if self.disable_pre_forward_hook:
+ return input
+
+ # Update reference to current layer
+ current_layer.layer_names.add(self.name)
+ is_quant_disabled = module.weight_quant.disable_quant
+
+ inp = self.process_input(input)
+ batch_size = inp.shape[0]
+
+ # Preprocess the input to compute the Hessian
+ if isinstance(self.layer, qnn.QuantLinear):
+ if len(inp.shape) > 2:
+ inp = inp.reshape((-1, sum(inp.shape[2:])))
+ # For QuantLinear layer, groups will be 1
+ inp_processed = inp.unsqueeze(0)
+
+ if isinstance(self.layer, SUPPORTED_CONV_OP):
+ # Pick the correct unfoldNd class
+ if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
+ unfold_impl = unfoldNd.UnfoldTransposeNd
+ else:
+ unfold_impl = unfoldNd.UnfoldNd
+
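+            # Unfold the convolution input into im2col-style rows so that the layer can be
+            # treated as a matrix multiplication, mirroring the linear case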
+ unfold = unfold_impl(
+ self.layer.kernel_size,
+ dilation=self.layer.dilation,
+ padding=self.layer.padding,
+ stride=self.layer.kernel_size)
+
+ # Split input based on how many groups in convolution
+ inp_by_group = torch.chunk(inp, self.groups, 1)
+ inp_processed = []
+ # Preprocess input by group
+ for i, inp in enumerate(inp_by_group):
+
+ inp = unfold(inp)
+
+ batch_size, num_blocks = inp.shape[0], inp.shape[-1]
+ inp = torch.transpose(inp, 1, 2) # shape (B, L, C*kernel_size[0]*kernel_size[1])
+ inp = inp.reshape(-1, inp.size(-1)) # shape (B*L, C*kernel_size[0]*kernel_size[1])
+
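+                # Sub-sample a random fraction p of the unfolded input rows for each batch
+                # element; the same indices are reused for the float and the quantized pass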
+ if not self.index_computed:
+ self.index_computed = True
+ self.rand_indices = np.concatenate([
+ np.random.choice(
+ np.arange(num_blocks * i, num_blocks * (i + 1)),
+ size=int(
+ self.p * num_blocks + 1 if self.p != 1 else self.p * num_blocks))
+                        for i in range(batch_size)])
+
+ indexes = self.rand_indices
+ if np.max(self.rand_indices) > inp.shape[0]:
+ indexes = self.rand_indices < inp.shape[0]
+ indexes = self.rand_indices[indexes]
+
+ inp = inp[indexes]
+ inp_processed.append(inp)
+ inp_processed = torch.stack(inp_processed)
+
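+        # Accumulate the pre-processed inputs across calibration batches: float inputs are
+        # collected while quantization is disabled, quantized inputs while it is enabled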
+ if is_quant_disabled:
+ if self.float_input is None:
+ self.float_input = inp_processed
+ else:
+ self.float_input = torch.cat([self.float_input, inp_processed], dim=1)
+ else:
+ if self.quantized_input is None:
+ self.quantized_input = inp_processed
+ else:
+ self.quantized_input = torch.cat([self.quantized_input, inp_processed], dim=1)
+        # When GPFQ runs over a group of parallel layers, we keep track of how many forward
+        # passes have been executed. Once we have executed as many passes as there are
+        # parallel layers, we raise StopFwdException
+ current_layer.forward_count += 1
+ if current_layer.forward_count == len(self.parallel_layers):
+ current_layer.forward_count = 0
+ raise StopFwdException
+
+ def single_layer_update(self):
+ weight = self.layer.weight.data
+ dev = weight.device
+ dtype = weight.dtype
+ if isinstance(self.layer, SUPPORTED_CONV_OP):
+ if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
+ weight = weight.transpose(1, 0) # This performs a view
+ weight = weight.flatten(1)
+ weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC]
+ U = torch.zeros(
+ weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
+ self.float_input = self.float_input.to(dev)
+ self.quantized_input = self.quantized_input.to(dev)
+ permutation_list = [torch.tensor(range(weight.shape[-1]))]
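+        # Greedy column-wise GPFQ update: U accumulates the residual between the float
+        # weights applied to the float inputs and the quantized weights applied to the
+        # quantized inputs; each column t is set to the least-squares projection of U onto
+        # the corresponding quantized input column and then quantized via get_quant_weights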
+ for t in range(weight.shape[-1]):
+ for group_index in range(self.groups):
+ U[group_index] += torch.matmul(
+ weight[group_index, :, t].unsqueeze(1),
+ self.float_input[group_index, :,
+ t].unsqueeze(0)) #[OC/Groups, 1] * [1, INSHAPE[1]]
+ norm = torch.linalg.norm(self.quantized_input[group_index, :, t], 2) ** 2
+ if norm > 0:
+ q_arg = U[group_index].matmul(self.quantized_input[group_index, :, t]) / norm
+ else:
+ q_arg = torch.zeros_like(U[group_index, :, 0])
+
+ weight[group_index, :, t] = q_arg
+ q = self.get_quant_weights(t, 0, permutation_list)
+ for group_index in range(self.groups):
+ U[group_index] -= torch.matmul(
+ q[group_index].unsqueeze(1),
+ self.quantized_input[group_index, :, t].unsqueeze(0))
+
+ del self.float_input
+ del self.quantized_input
diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py
index 8f8ffb6ae..b224e0a37 100644
--- a/src/brevitas/graph/gptq.py
+++ b/src/brevitas/graph/gptq.py
@@ -2,11 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause
from copy import deepcopy
-from dataclasses import dataclass
-from dataclasses import field
-from functools import partial
import math
-from operator import attrgetter
from typing import List, Optional, Set
import warnings
@@ -16,28 +12,16 @@
from torch.linalg import LinAlgError
except:
LinAlgError = RuntimeError
-
import unfoldNd
-from brevitas.graph.calibrate import DisableEnableQuantization
+from brevitas.graph.gpxq import GPxQ
+from brevitas.graph.gpxq import gpxq_mode
+from brevitas.graph.gpxq import StopFwdException
+from brevitas.graph.gpxq import SUPPORTED_CONV_OP
import brevitas.nn as qnn
-from brevitas.quant_tensor import QuantTensor
-
-SUPPORTED_CONV_OP = (
- qnn.QuantConv2d, qnn.QuantConv1d, qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)
-
-
-class StopFwdException(Exception):
- pass
-@dataclass
-class LayerHandler:
- layer_names: Set = field(default_factory=set)
- forward_count: int = 0
-
-
-class gptq_mode:
+class gptq_mode(gpxq_mode):
"""
Apply GPTQ algorithm https://arxiv.org/abs/2210.17323.
@@ -63,99 +47,28 @@ def __init__(
model,
group_of_parallel_layers: Optional[List[str]] = None,
inplace: bool = True,
+ create_weight_orig: bool = True,
use_quant_activations: bool = True,
num_blocks: int = 100,
- act_order: bool = False,
- return_forward_output: bool = False) -> None:
+ return_forward_output: bool = False,
+ act_order: bool = False) -> None:
if not inplace:
model = deepcopy(model)
- self.model = model
- self.use_quant_activations = use_quant_activations
- self.hook_dict = dict()
- self.gptq_layers = dict()
- # reference for each layer to update
- self.current_layer = LayerHandler()
- # How many layer to optimize
- self.num_layers = 0
- # Quantize following magnitude of activation
- self.act_order = act_order
- # How many subblock to use during GPTQ for each layer
- self.num_blocks = num_blocks
+ super().__init__(
+ model,
+ group_of_parallel_layers,
+ inplace,
+ create_weight_orig,
+ use_quant_activations,
+ act_order,
+ return_forward_output)
- self.disable_quant_inference = DisableEnableQuantization()
self.orig_forward = self.model.forward
self.model.forward = self.catch_stopfwd
- self.group_of_parallel_layers = group_of_parallel_layers
- self.return_forward_output = return_forward_output
-
- def _is_module_supported(self, module):
- if isinstance(module, SUPPORTED_CONV_OP):
- return True
- elif isinstance(module, qnn.QuantLinear):
- return True
- else:
- return False
-
- def __enter__(self):
- # The user can specify on which layers to apply gptq in parallel.
- # All the others will be executed sequentially
- dict_of_layers = {
- name: [(name, module)] for name,
- module in self.model.named_modules() if self._is_module_supported(module)}
- if self.group_of_parallel_layers is not None:
- for parallel_layers in self.group_of_parallel_layers:
- for name in parallel_layers:
- if name not in dict_of_layers:
- raise ValueError(
- "The layer {} is not present in the model or it is not supported for GPTQ"
- .format(name))
- del dict_of_layers[name]
- names = '_'.join(parallel_layers)
- dict_of_layers[names] = [
- (name, attrgetter(name)(self.model)) for name in parallel_layers]
-
- # Print warning if hooks are attached to any module, since the normal forward flow of the
- # network is highly disrupted during GPTQ
- for _, parallel_layers in dict_of_layers.items():
- for name, module in parallel_layers:
- if len(module._forward_hooks) > 0 or len(module._forward_pre_hooks):
- warnings.warn(
- f'Hooks detected during setup for GPTQ. '
- f'Behaviour might deviate from what expected.')
-
- # Attach hooks for GPTQ
- if self._is_module_supported(module):
- gptq = GPTQ(
- module,
- name,
- num_blocks=self.num_blocks,
- act_order=self.act_order,
- parallel_layers=parallel_layers)
- hook_fn = partial(gptq.update_batch, current_layer=self.current_layer)
- self.hook_dict[name] = module.register_forward_pre_hook(hook_fn)
- self.gptq_layers[name] = gptq
- if not self.use_quant_activations:
- self.disable_quant_inference.disable_act_quantization(
- self.model, is_training=self.model.training)
- self.disable_quant_inference.disable_bias_quantization(
- self.model, is_training=self.model.training)
-
- self.num_layers = len(dict_of_layers)
- return self
-
- def __exit__(self, type, value, traceback):
- self.model.forward = self.orig_forward
- if not self.use_quant_activations:
- self.disable_quant_inference.enable_act_quantization(
- self.model, is_training=self.model.training)
- self.disable_quant_inference.enable_bias_quantization(
- self.model, is_training=self.model.training)
-
- def update(self):
- for name in self.current_layer.layer_names:
- self.gptq_layers[name].single_layer_update()
- self.hook_dict[name].remove()
- self.current_layer.layer_names.clear()
+        # How many sub-blocks to use during GPTQ for each layer
+ self.num_blocks = num_blocks
+ self.class_implementation = GPTQ
+ GPTQ.num_blocks = num_blocks
def catch_stopfwd(self, *args, **kwargs):
try:
@@ -165,15 +78,15 @@ def catch_stopfwd(self, *args, **kwargs):
finally:
if self.return_forward_output:
# If we want to return the output of the network, we need to disable all hooks
- for name, gptq_class in self.gptq_layers.items():
- gptq_class.disable_pre_forward_hook = True
+ for name, gpxq_class in self.gpxq_layers.items():
+ gpxq_class.disable_pre_forward_hook = True
out = self.orig_forward(*args, **kwargs)
- for name, gptq_class in self.gptq_layers.items():
- gptq_class.disable_pre_forward_hook = False
+ for name, gpxq_class in self.gpxq_layers.items():
+ gpxq_class.disable_pre_forward_hook = False
return out
-class GPTQ():
+class GPTQ(GPxQ):
"""
Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE:
@@ -191,74 +104,29 @@ class GPTQ():
See the License for the specific language governing permissions and
limitations under the License.
"""
+ num_blocks = 100
- def __init__(self, layer, name, num_blocks, act_order, parallel_layers=1) -> None:
- self.layer = layer
- self.name = name
- self.num_blocks = num_blocks
- self.act_order = act_order
+ def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None:
+ super().__init__(layer, name, act_order, parallel_layers, create_weight_orig)
- weight = layer.weight.data
- dev = weight.device
-
- # By default, use groups = 1
- self.groups = 1
- if isinstance(self.layer, SUPPORTED_CONV_OP):
- if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
- weight = weight.transpose(1, 0) # This performs a view
- weight = weight.flatten(1)
- self.groups = self.layer.groups
-
- # Number of rows is equal to the output channels (OC)
- self.rows = weight.shape[0]
- # Number of columns is equal to the input channels (IC)
- self.columns = weight.shape[1]
+ dev = self.layer.weight.device
# Define how many columns to update in each mini-block
- self.blocksize = math.ceil(self.columns / self.num_blocks)
+ self.blocksize = math.ceil(self.columns / GPTQ.num_blocks)
# Initialize Hessian matrix and counter. We need it in float32 to compute the inverse
self.H = torch.zeros((self.groups, self.columns, self.columns),
device=dev,
dtype=torch.float32)
self.nsamples = 0
- self.parallel_layers = parallel_layers
-
- self.disable_pre_forward_hook = False
- # Some layers require knowledge from quant inputs to compute quant weights
- self.quant_input = None
def update_batch(self, module, input, current_layer):
if self.disable_pre_forward_hook:
return input
+
# Update reference to current layer
current_layer.layer_names.add(self.name)
-
- # Input is a tuple, so we take first element
- inp = input[0]
- # If using Quant Activations, inp could be QuantTensor
- if isinstance(inp, QuantTensor):
- if self.layer.weight_quant_requires_quant_input:
- # Can minimize memory allocation by not storing actual values
- self.quant_input = QuantTensor(
- value=None,
- scale=inp.scale,
- zero_point=inp.zero_point,
- bit_width=inp.bit_width,
- signed=inp.signed,
- training=inp.training)
- inp = inp.value
-
- # If input is unbatched, add batch_size = 1
- if len(inp.shape) == 1:
- warnings.warn("Found unbatched input, adding batch dimension equal to 1")
- inp = inp.unsqueeze(0)
-
- # Define batch size before re-organizing the input
- if hasattr(inp, 'names') and 'N' in inp.names:
- batch_dim = inp.names.index('N')
- inp.rename_(None)
- inp = inp.transpose(0, batch_dim)
+ inp = self.process_input(input)
batch_size = inp.shape[0]
# Preprocess the input to compute the Hessian
@@ -390,53 +258,3 @@ def single_layer_update(self, percdamp=.01):
perm = permutation_list[group_index]
weight[group_index, :, perm[i2:]] -= (
error_block[group_index].matmul(h_inv[group_index, i1:i2, i2:])).to(dtype)
-
- def get_quant_weights(self, i, i1, permutation_list):
- # We need to recompute quant weights at runtime since our float weights are being updated
- # Add offset in case of blockwise computation (e.g., GPTQ)
- i = i1 + i
- # For QuantLinear and for some QuantConvolutional layers, we exploit the possibility
- # of quantizing only a subset of the entire matrix speeding up the computation of GPTQ
- if isinstance(self.layer, qnn.QuantLinear):
- index = permutation_list[0][i]
- subtensor_slice_list = [None, (index, index + 1)]
- q = self.layer.quant_weight(
- subtensor_slice_list=subtensor_slice_list,
- quant_input=self.quant_input).value.unsqueeze(0) # [1, OC, 1]
- elif isinstance(self.layer, SUPPORTED_CONV_OP):
- # For depthwise and ConvTranspose we fall back to quantizing the entire martix.
- # For all other cases, we create a mask that represent the slicing we will perform on the weight matrix
- # and we quantize only the selected dimensions.
- if self.groups > 1 or (self.groups == 1 and isinstance(
- self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d))):
-
- quant_weight = self.layer.quant_weight(quant_input=self.quant_input)
- quant_weight = quant_weight.value
-
- if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
- quant_weight = quant_weight.transpose(1, 0) # This performs a view
- quant_weight = quant_weight.flatten(1)
- quant_weight = quant_weight.view(self.groups, -1, quant_weight.shape[-1])
-
- if self.act_order:
- for ii, perm in enumerate(permutation_list):
- quant_weight[ii, :, :] = quant_weight[ii, :, perm]
-
- q = quant_weight[:, :, i:i + 1] # [groups, OC/groups, 1]
- else:
- index = permutation_list[0][i]
- shapes = self.layer.weight.shape[1:]
- index_2d_to_nd = []
- residual_index = index.item()
- for shape in shapes[::-1]:
- index_2d_to_nd.append((residual_index % shape, residual_index % shape + 1))
- residual_index = residual_index // shape
- index_2d_to_nd = index_2d_to_nd[::-1]
- index_2d_to_nd.insert(0, None)
- q = self.layer.quant_weight(
- subtensor_slice_list=index_2d_to_nd,
- quant_input=self.quant_input).value.flatten(1) # [OC, 1]
- q = q.unsqueeze(0) # [1, OC, 1]
- # We need to remove the last dim
- q = q.squeeze(2) # [groups, OC/groups] or [1, OC]
- return q
diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py
new file mode 100644
index 000000000..b13c46683
--- /dev/null
+++ b/src/brevitas/graph/gpxq.py
@@ -0,0 +1,252 @@
+# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+from abc import ABC
+from abc import abstractmethod
+from copy import deepcopy
+from dataclasses import dataclass
+from dataclasses import field
+from functools import partial
+from operator import attrgetter
+from typing import List, Optional, Set
+import warnings
+
+from brevitas.graph.calibrate import DisableEnableQuantization
+import brevitas.nn as qnn
+from brevitas.quant_tensor import QuantTensor
+
+SUPPORTED_CONV_OP = (
+ qnn.QuantConv2d, qnn.QuantConv1d, qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)
+
+
+class StopFwdException(Exception):
+ pass
+
+
+@dataclass
+class LayerHandler:
+ layer_names: Set = field(default_factory=set)
+ forward_count: int = 0
+
+
+class gpxq_mode(ABC):
+
+ def __init__(
+ self,
+ model,
+ group_of_parallel_layers: Optional[List[str]] = None,
+ inplace: bool = True,
+ create_weight_orig: bool = True,
+ use_quant_activations: bool = True,
+ act_order: bool = False,
+ return_forward_output: bool = False) -> None:
+
+ if not inplace:
+ model = deepcopy(model)
+ self.model = model
+ self.create_weight_orig = create_weight_orig
+ self.use_quant_activations = use_quant_activations
+ self.hook_dict = dict()
+ self.gpxq_layers = dict()
+ # reference for each layer to update
+ self.current_layer = LayerHandler()
+        # How many layers to optimize
+        self.num_layers = 0
+        # Quantize columns following the magnitude of the activations
+        self.act_order = act_order
+
+ self.disable_quant_inference = DisableEnableQuantization()
+
+ self.group_of_parallel_layers = group_of_parallel_layers
+ self.return_forward_output = return_forward_output
+
+ def _is_module_supported(self, module):
+ if isinstance(module, SUPPORTED_CONV_OP):
+ return True
+ elif isinstance(module, qnn.QuantLinear):
+ return True
+ else:
+ return False
+
+ def __enter__(self):
+        # The user can specify on which layers to apply GPxQ in parallel.
+ # All the others will be executed sequentially
+ dict_of_layers = {
+ name: [(name, module)] for name,
+ module in self.model.named_modules() if self._is_module_supported(module)}
+ if self.group_of_parallel_layers is not None:
+ for parallel_layers in self.group_of_parallel_layers:
+ for name in parallel_layers:
+ if name not in dict_of_layers:
+ raise ValueError(
+                            "The layer {} is not present in the model or is not supported for GPxQ"
+ .format(name))
+ del dict_of_layers[name]
+ names = '_'.join(parallel_layers)
+ dict_of_layers[names] = [
+ (name, attrgetter(name)(self.model)) for name in parallel_layers]
+
+ # Print warning if hooks are attached to any module, since the normal forward flow of the
+ # network is highly disrupted during GPxQ
+ for _, parallel_layers in dict_of_layers.items():
+ for name, module in parallel_layers:
+ if len(module._forward_hooks) > 0 or len(module._forward_pre_hooks):
+ warnings.warn(
+                        'Hooks detected during setup for GPxQ. '
+                        'Behaviour might deviate from what is expected.')
+
+            # Attach hooks for GPxQ
+ if self._is_module_supported(module):
+ gpxq = self.class_implementation(
+ module,
+ name,
+ act_order=self.act_order,
+ parallel_layers=parallel_layers,
+ create_weight_orig=self.create_weight_orig)
+ hook_fn = partial(gpxq.update_batch, current_layer=self.current_layer)
+ self.hook_dict[name] = module.register_forward_pre_hook(hook_fn)
+ self.gpxq_layers[name] = gpxq
+ if not self.use_quant_activations:
+ self.disable_quant_inference.disable_act_quantization(
+ self.model, is_training=self.model.training)
+ self.disable_quant_inference.disable_bias_quantization(
+ self.model, is_training=self.model.training)
+
+ self.num_layers = len(dict_of_layers)
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.model.forward = self.orig_forward
+ if not self.use_quant_activations:
+ self.disable_quant_inference.enable_act_quantization(
+ self.model, is_training=self.model.training)
+ self.disable_quant_inference.enable_bias_quantization(
+ self.model, is_training=self.model.training)
+
+ def update(self):
+ for name in self.current_layer.layer_names:
+ self.gpxq_layers[name].single_layer_update()
+ self.hook_dict[name].remove()
+ self.current_layer.layer_names.clear()
+
+ @abstractmethod
+ def catch_stopfwd(self, *args, **kwargs):
+ pass
+
+
+class GPxQ(ABC):
+
+ def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None:
+ self.layer = layer
+ self.name = name
+ self.act_order = act_order
+
+ weight = layer.weight.data
+
+ if create_weight_orig and not hasattr(self.layer, 'weight_orig'):
+ self.layer.register_buffer('weight_orig', layer.weight.detach().clone())
+
+ # By default, use groups = 1
+ self.groups = 1
+ if isinstance(self.layer, SUPPORTED_CONV_OP):
+ if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
+ weight = weight.transpose(1, 0) # This performs a view
+ weight = weight.flatten(1)
+ self.groups = self.layer.groups
+
+ # Number of rows is equal to the output channels (OC)
+ self.rows = weight.shape[0]
+ # Number of columns is equal to the input channels (IC)
+ self.columns = weight.shape[1]
+ self.parallel_layers = parallel_layers
+
+ self.disable_pre_forward_hook = False
+ # Some layers require knowledge from quant inputs to compute quant weights
+ self.quant_input = None
+
+ def process_input(self, inp):
+ # Input is a tuple, so we take first element
+ inp = inp[0]
+ # If using Quant Activations, inp could be QuantTensor
+ if isinstance(inp, QuantTensor):
+ if self.layer.weight_quant_requires_quant_input:
+ # Can minimize memory allocation by not storing actual values
+ self.quant_input = QuantTensor(
+ value=None,
+ scale=inp.scale,
+ zero_point=inp.zero_point,
+ bit_width=inp.bit_width,
+ signed=inp.signed,
+ training=inp.training)
+ inp = inp.value
+
+ # If input is unbatched, add batch_size = 1
+ if len(inp.shape) == 1:
+ warnings.warn("Found unbatched input, adding batch dimension equal to 1")
+ inp = inp.unsqueeze(0)
+
+ # Define batch size before re-organizing the input
+ if hasattr(inp, 'names') and 'N' in inp.names:
+ batch_dim = inp.names.index('N')
+ inp.rename_(None)
+ inp = inp.transpose(0, batch_dim)
+ return inp
+
+ @abstractmethod
+ def update_batch(self):
+ pass
+
+ @abstractmethod
+ def single_layer_update(self):
+ pass
+
+ def get_quant_weights(self, i, i1, permutation_list):
+ # We need to recompute quant weights at runtime since our float weights are being updated
+ # Add offset in case of blockwise computation
+ i = i1 + i
+ # For QuantLinear and for some QuantConvolutional layers, we exploit the possibility
+        # of quantizing only a subset of the entire matrix, speeding up the computation of GPxQ
+ if isinstance(self.layer, qnn.QuantLinear):
+ index = permutation_list[0][i]
+ subtensor_slice_list = [None, (index, index + 1)]
+ q = self.layer.quant_weight(
+ subtensor_slice_list=subtensor_slice_list,
+ quant_input=self.quant_input).value.unsqueeze(0) # [1, OC, 1]
+ elif isinstance(self.layer, SUPPORTED_CONV_OP):
+            # For depthwise and ConvTranspose we fall back to quantizing the entire matrix.
+            # For all other cases, we create a mask that represents the slicing we will perform
+            # on the weight matrix and we quantize only the selected dimensions.
+ if self.groups > 1 or (self.groups == 1 and isinstance(
+ self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d))):
+
+ quant_weight = self.layer.quant_weight(quant_input=self.quant_input)
+ quant_weight = quant_weight.value
+
+ if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
+ quant_weight = quant_weight.transpose(1, 0) # This performs a view
+ quant_weight = quant_weight.flatten(1)
+ quant_weight = quant_weight.view(self.groups, -1, quant_weight.shape[-1])
+
+ if self.act_order:
+ for ii, perm in enumerate(permutation_list):
+ quant_weight[ii, :, :] = quant_weight[ii, :, perm]
+
+ q = quant_weight[:, :, i:i + 1] # [groups, OC/groups, 1]
+ else:
+ index = permutation_list[0][i]
+ shapes = self.layer.weight.shape[1:]
+ index_2d_to_nd = []
+ residual_index = index.item()
+ for shape in shapes[::-1]:
+ index_2d_to_nd.append((residual_index % shape, residual_index % shape + 1))
+ residual_index = residual_index // shape
+ index_2d_to_nd = index_2d_to_nd[::-1]
+ index_2d_to_nd.insert(0, None)
+ q = self.layer.quant_weight(
+ subtensor_slice_list=index_2d_to_nd,
+ quant_input=self.quant_input).value.flatten(1) # [OC, 1]
+ q = q.unsqueeze(0) # [1, OC, 1]
+ # We need to remove the last dim
+ q = q.squeeze(2) # [groups, OC/groups] or [1, OC]
+ return q
diff --git a/src/brevitas/nn/mixin/parameter.py b/src/brevitas/nn/mixin/parameter.py
index f65621c3f..095c981f1 100644
--- a/src/brevitas/nn/mixin/parameter.py
+++ b/src/brevitas/nn/mixin/parameter.py
@@ -61,6 +61,9 @@ def quant_weight(
self,
quant_input: Optional[QuantTensor] = None,
subtensor_slice_list: List[Optional[Tuple[int, int]]] = None):
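+        # When weight quantization is disabled and the original weights have been cached
+        # (e.g. by GPxQ), fall back to weight_orig so that the float reference is not
+        # affected by in-place weight updates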
+ weights_to_quantize = self.weight
+ if not self.weight_quant.is_quant_enabled and hasattr(self, 'weight_orig'):
+ weights_to_quantize = self.weight_orig
if subtensor_slice_list is not None:
# prepare the quantizer for a subtensor input, if any modifications are required
# we set a list of tuples rather than a list of slices so that it's jit friendly
@@ -95,9 +98,9 @@ def quant_weight(
input_bit_width = None
input_is_signed = None
out = self.weight_quant(
- self.weight[weight_slice_tuple], input_bit_width, input_is_signed)
+ weights_to_quantize[weight_slice_tuple], input_bit_width, input_is_signed)
else:
- out = self.weight_quant(self.weight[weight_slice_tuple])
+ out = self.weight_quant(weights_to_quantize[weight_slice_tuple])
if subtensor_slice_list is not None:
# Restore the quantizer behaviour to full tensor quantization
# The modules to slice should have been cached already at this point
diff --git a/src/brevitas_examples/imagenet_classification/ptq/README.md b/src/brevitas_examples/imagenet_classification/ptq/README.md
index e0e7c7455..29386659b 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/README.md
+++ b/src/brevitas_examples/imagenet_classification/ptq/README.md
@@ -36,6 +36,7 @@ Furthermore, Brevitas additional PTQ techniques can be enabled:
- If Graph equalization is enabled, the _merge\_bias_ technique can be enabled.[2 ] [3 ].
- GPTQ [4 ].
- Learned Round [5 ].
+- GPFQ [6 ].
Internally, when defining a quantized model programmatically, Brevitas leverages `torch.fx` and its `symbolic_trace` functionality, meaning that an input model is required to pass symbolic tracing for it to work.
@@ -85,7 +86,8 @@ usage: ptq_evaluate.py [-h] --calibration-dir CALIBRATION_DIR --validation-dir
[--bias-corr | --no-bias-corr]
[--graph-eq-merge-bias | --no-graph-eq-merge-bias]
[--weight-narrow-range | --no-weight-narrow-range]
- [--gptq | --no-gptq]
+ [--gpfq-p GPFQ_P] [--gptq | --no-gptq]
+ [--gpfq | --no-gpfq]
[--gptq-act-order | --no-gptq-act-order]
[--learned-round | --no-learned-round]
[--calibrate-bn | --no-calibrate-bn]
@@ -171,8 +173,11 @@ optional arguments:
Enable Narrow range for weight quantization (default: enabled)
--no-weight-narrow-range
Disable Narrow range for weight quantization (default: enabled)
+ --gpfq-p GPFQ_P P parameter for GPFQ (default: 0.25)
--gptq Enable GPTQ (default: enabled)
--no-gptq Disable GPTQ (default: enabled)
+ --gpfq Enable GPFQ (default: disabled)
+ --no-gpfq Disable GPFQ (default: disabled)
--gptq-act-order Enable GPTQ Act order heuristic (default: disabled)
--no-gptq-act-order Disable GPTQ Act order heuristic (default: disabled)
--learned-round Enable Learned round (default: disabled)
@@ -208,3 +213,4 @@ and a `RESULTS_IMGCLSMOB.csv` with the results on manually quantized models star
[3 ]: https://github.com/openppl-public/ppq/blob/master/ppq/quantization/algorithm/equalization.py
[4 ]: https://arxiv.org/abs/2210.17323
[5 ]: https://arxiv.org/abs/2004.10568
+[6 ]: https://arxiv.org/abs/2201.11113
diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py
index 9acafbf58..9a97b4794 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py
@@ -24,6 +24,7 @@
from brevitas.graph.target.flexml import preprocess_for_flexml_quantize
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_bias_correction
+from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gpfq
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gptq
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_learned_round_learning
from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate
@@ -52,6 +53,7 @@
'bias_bit_width': [32, 16], # Bias Bit-Width for Po2 scale
'weight_quant_granularity': ['per_tensor', 'per_channel'], # Scaling Per Output Channel
'act_quant_type': ['asym', 'sym'], # Act Quant Type
+    'weight_param_method': ['stats', 'mse'],  # Weight Param Method
'act_param_method': ['stats', 'mse'], # Act Param Method
'bias_corr': [True], # Bias Correction
'graph_eq_iterations': [0, 20], # Graph Equalization
@@ -60,6 +62,8 @@
'learned_round': [False, True], # Enable/Disable Learned Round
'gptq': [False, True], # Enable/Disable GPTQ
'gptq_act_order': [False, True], # Use act_order euristics for GPTQ
+ 'gpfq': [False, True], # Enable/Disable GPFQ
+ 'gpfq_p': [0.25, 0.75], # GPFQ P
'act_quant_percentile': [99.9, 99.99, 99.999], # Activation Quantization Percentile
}
@@ -71,13 +75,16 @@
'bias_bit_width': [32], # Bias Bit-Width for Po2 scale
'weight_quant_granularity': ['per_channel'], # Scaling Per Output Channel
'act_quant_type': ['sym'], # Act Quant Type
- 'act_param_method': ['stats'], # Act Param Method
+ 'act_param_method': ['mse'], # Act Param Method
+    'weight_param_method': ['stats'],  # Weight Param Method
'bias_corr': [True], # Bias Correction
'graph_eq_iterations': [20], # Graph Equalization
'graph_eq_merge_bias': [True], # Merge bias for Graph Equalization
'act_equalization': [None], # Perform Activation Equalization (Smoothquant)
'learned_round': [False], # Enable/Disable Learned Round
'gptq': [True], # Enable/Disable GPTQ
+ 'gpfq': [False], # Enable/Disable GPFQ
+ 'gpfq_p': [0.25], # GPFQ P
'gptq_act_order': [False], # Use act_order euristics for GPTQ
'act_quant_percentile': [99.999], # Activation Quantization Percentile
}
@@ -114,35 +121,36 @@ def main():
args.gpu = get_gpu_index(args.idx)
print("Iter {}, GPU {}".format(args.idx, args.gpu))
-
- options_names = [k.replace('_', ' ').capitalize() for k in OPTIONS.keys()]
- torchvision_df = pd.DataFrame(
- columns=options_names + [
- 'Top 1% floating point accuracy',
- 'Top 1% quant accuracy',
- 'Floating point accuracy - quant accuracy',
- 'Quant accuracy / floating point accuracy',
- 'Calibration size',
- 'Calibration batch size',
- 'Torch version',
- 'Brevitas version'])
try:
- ptq_torchvision_models(torchvision_df, args)
+ ptq_torchvision_models(args)
except Exception as E:
print("Exception at index {}: {}".format(args.idx, E))
-def ptq_torchvision_models(df, args):
+def ptq_torchvision_models(args):
# Generate all possible combinations, including invalid ones
# Split stats and mse due to the act_quant_percentile value
- percentile_options = OPTIONS.copy()
- percentile_options['act_param_method'] = ['stats']
- mse_options = OPTIONS.copy()
- mse_options['act_param_method'] = ['mse']
- mse_options['act_quant_percentile'] = [None]
+
+ if 'stats' in OPTIONS['act_param_method']:
+ percentile_options = OPTIONS.copy()
+ percentile_options['act_param_method'] = ['stats']
+ else:
+ percentile_options = None
+
+ if 'mse' in OPTIONS['act_param_method']:
+ mse_options = OPTIONS.copy()
+ mse_options['act_param_method'] = ['mse']
+ mse_options['act_quant_percentile'] = [None]
+ else:
+ mse_options = None
+
+ # Combine MSE and Percentile combinations, if they are defined
+ combinations = []
+ if mse_options is not None:
+ combinations += list(product(*mse_options.values()))
+ if percentile_options is not None:
+ combinations += list(product(*percentile_options.values()))
# Combine the two sets of combinations
- combinations = list(product(*percentile_options.values())) + list(
- product(*mse_options.values()))
# Generate Namespace for each configuration
configs = [
SimpleNamespace(**{k: v
@@ -152,10 +160,12 @@ def ptq_torchvision_models(df, args):
configs = list(map(validate_config, configs))
# Drop invalid configurations
configs = list(config for config in configs if config.is_valid)
+
if args.idx > len(configs):
return
config_namespace = configs[args.idx]
+ print(config_namespace)
fp_accuracy = TORCHVISION_TOP1_MAP[config_namespace.model_name]
# Get model-specific configurations about input shapes and normalization
@@ -212,6 +222,8 @@ def ptq_torchvision_models(df, args):
backend=config_namespace.target_backend,
act_bit_width=config_namespace.act_bit_width,
weight_bit_width=config_namespace.weight_bit_width,
+ weight_param_method=config_namespace.weight_param_method,
+ act_param_method=config_namespace.act_param_method,
bias_bit_width=config_namespace.bias_bit_width,
weight_quant_granularity=config_namespace.weight_quant_granularity,
act_quant_percentile=config_namespace.act_quant_percentile,
@@ -228,6 +240,10 @@ def ptq_torchvision_models(df, args):
print("Starting calibration")
calibrate(calib_loader, quant_model)
+ if config_namespace.gpfq:
+ print("Performing GPFQ:")
+ apply_gpfq(calib_loader, quant_model, p=config_namespace.gpfq_p)
+
if config_namespace.gptq:
print("Performing gptq")
apply_gptq(calib_loader, quant_model, config_namespace.gptq_act_order)
@@ -292,6 +308,10 @@ def validate_config(config_namespace):
if not config_namespace.gptq and config_namespace.gptq_act_order:
is_valid = False
+    # If GPFQ is disabled, keep only the configuration with p == 0.25 to avoid duplicate runs
+ if not config_namespace.gpfq and config_namespace.gpfq_p == 0.75:
+ is_valid = False
+
if config_namespace.act_equalization == 'layerwise' and config_namespace.target_backend == 'fx':
is_valid = False
if config_namespace.act_bit_width < config_namespace.weight_bit_width:
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
index 7982059f8..f2ae5092c 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -13,6 +13,7 @@
from brevitas.graph.calibrate import calibration_mode
from brevitas.graph.calibrate import norm_correction_mode
from brevitas.graph.equalize import activation_equalization_mode
+from brevitas.graph.gpfq import gpfq_mode
from brevitas.graph.gptq import gptq_mode
from brevitas.graph.quantize import layerwise_quantize
from brevitas.graph.quantize import quantize
@@ -203,7 +204,7 @@ def kwargs_prefix(prefix, weight_kwargs):
weight_quant = weight_quant.let(zero_point_impl=ParameterFromStatsFromParameterZeroPoint)
if act_quant is not None:
act_quant = act_quant.let(**{'high_percentile_q': act_quant_percentile, 'dtype': dtype})
- if act_quant_type == 'asym':
+ if act_quant_type == 'asym' and act_quant_percentile is not None:
act_quant = act_quant.let(**{'low_percentile_q': 100 - act_quant_percentile})
if sym_act_quant is not None:
sym_act_quant = sym_act_quant.let(
@@ -213,7 +214,7 @@ def kwargs_prefix(prefix, weight_kwargs):
per_tensor_act_quant = per_tensor_act_quant.let(
**{
'high_percentile_q': act_quant_percentile, 'dtype': dtype})
- if act_quant_type == 'asym':
+ if act_quant_type == 'asym' and act_quant_percentile is not None:
per_tensor_act_quant = per_tensor_act_quant.let(
**{'low_percentile_q': 100 - act_quant_percentile})
@@ -360,6 +361,21 @@ def apply_gptq(calib_loader, model, act_order=False):
gptq.update()
+def apply_gpfq(calib_loader, model, p=0.25):
+ model.eval()
+ dtype = next(model.parameters()).dtype
+ device = next(model.parameters()).device
+ with torch.no_grad():
+ with gpfq_mode(model, p=p, use_quant_activations=True) as gpfq:
+ gpfq_model = gpfq.model
+            for _ in tqdm(range(gpfq.num_layers)):
+                for images, target in calib_loader:
+ images = images.to(device)
+ images = images.to(dtype)
+ gpfq_model(images)
+ gpfq.update()
+
+
def apply_learned_round_learning(
model, dataloader, optimizer_class=torch.optim.Adam, iters=1000, optimizer_lr=1e-1):
layers = []
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
index fdf2b966c..a560cd4c3 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -22,6 +22,7 @@
from brevitas.graph.target.flexml import preprocess_for_flexml_quantize
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_bias_correction
+from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gpfq
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gptq
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_learned_round_learning
from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate
@@ -165,7 +166,10 @@
'weight-narrow-range',
default=True,
help='Narrow range for weight quantization (default: enabled)')
+parser.add_argument(
+ '--gpfq-p', default=0.25, type=float, help='P parameter for GPFQ (default: 0.25)')
add_bool_arg(parser, 'gptq', default=True, help='GPTQ (default: enabled)')
+add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
add_bool_arg(
parser, 'gptq-act-order', default=False, help='GPTQ Act order heuristic (default: disabled)')
add_bool_arg(parser, 'learned-round', default=False, help='Learned round (default: disabled)')
@@ -191,6 +195,7 @@ def main():
f"a{args.act_bit_width}"
f"w{args.weight_bit_width}_"
f"{'gptq_' if args.gptq else ''}"
+ f"{'gpfq_' if args.gpfq else ''}"
f"{'gptq_act_order_' if args.gptq_act_order else ''}"
f"{'learned_round_' if args.learned_round else ''}"
f"{'weight_narrow_range_' if args.weight_narrow_range else ''}"
@@ -211,6 +216,8 @@ def main():
f"Activation bit width: {args.act_bit_width} - "
f"Weight bit width: {args.weight_bit_width} - "
f"GPTQ: {args.gptq} - "
+ f"GPFQ: {args.gpfq} - "
+ f"GPFQ P: {args.gpfq_p} - "
f"GPTQ Act Order: {args.gptq_act_order} - "
f"Learned Round: {args.learned_round} - "
f"Weight narrow range: {args.weight_narrow_range} - "
@@ -299,9 +306,13 @@ def main():
print("Starting activation calibration:")
calibrate(calib_loader, quant_model)
+ if args.gpfq:
+ print("Performing GPFQ:")
+ apply_gpfq(calib_loader, quant_model, p=args.gpfq_p)
+
if args.gptq:
print("Performing GPTQ:")
- apply_gptq(calib_loader, quant_model, args.gptq_act_order)
+ apply_gptq(calib_loader, quant_model, act_order=args.gptq_act_order)
if args.learned_round:
print("Applying Learned Round:")