Feat (GPFA2Q): add A2Q bound to GPFQ
fabianandresgrob committed Dec 21, 2023
1 parent 2e6e179 commit adab5f6
Showing 4 changed files with 156 additions and 29 deletions.
23 changes: 3 additions & 20 deletions src/brevitas/core/scaling/pre_scaling.py
@@ -14,6 +14,7 @@
from brevitas.core.stats import SCALAR_SHAPE
from brevitas.core.stats.stats_wrapper import _Stats
from brevitas.function import abs_binary_sign_grad
from brevitas.function import get_upper_bound_on_l1_norm

__all__ = ["ParameterPreScalingWeightNorm", "AccumulatorAwareParameterPreScaling"]

@@ -170,33 +171,15 @@ def __init__(
)
self.accumulator_bit_width = accumulator_bit_width_impl

@brevitas.jit.script_method
def get_upper_bound_on_l1_norm(self, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Calculate the upper bound on the l1-norm of the weights using the derivations from
`Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance`
by I. Colbert, A. Pappalardo, and J. Petri-Koenig."""
assert input_bit_width is not None, "A2Q relies on input bit-width."
assert input_is_signed is not None, "A2Q relies on input sign."
input_is_signed = float(input_is_signed) # 1. if signed else 0.
# This is the minimum of the two maximum magnitudes a P-bit signed accumulator can represent,
# i.e. min(|-2^{P-1}|, 2^{P-1}-1). Allowing the larger magnitude 2^{P-1} would risk overflow
# on the positive side of the accumulator range, so we use 2^{P-1}-1.
max_accumulator_bit_width = self.accumulator_bit_width() # P
max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. # 2^{P-1}-1
# This is the maximum possible magnitude that the input data could take: 2^{N-1} when the
# data is signed and 2^N - 1 when unsigned. We use a slightly looser bound (2^N in the
# unsigned case) to simplify the derivations used in export validation.
max_input_mag_inverse = pow(2., input_is_signed - input_bit_width)
return max_accumulator_mag * max_input_mag_inverse

@brevitas.jit.script_method
def forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Takes weights as input and returns the pre-clipping scaling factor"""
weights = self.stats_input_view_shape_impl(weights)
d_w = self.stats(weights) # denominator for weight normalization
s = self.scaling_impl(weights) # s
g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) # g
T = self.get_upper_bound_on_l1_norm(input_bit_width, input_is_signed) # T / s
T = get_upper_bound_on_l1_norm(
self.accumulator_bit_width(), input_bit_width, input_is_signed) # T / s
g = torch.clamp_max(g / s, T)
value = d_w / g # calculating final pre-clipping scaling factor
# re-apply clamp_min_ste from restrict_scaling_impl to the specified pre_scaling_min_val
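
For context, here is a minimal standalone sketch (illustrative values, not Brevitas code) of what the clamp in forward achieves: the learned norm parameter g, divided by the quantized-weight scale s, is capped at the A2Q bound T, so channels that would otherwise exceed the accumulator budget end up with a larger pre-clipping scale and hence smaller integer weight magnitudes.

import torch

# Illustrative per-channel quantities (made up):
d_w = torch.tensor([4.0, 12.0])   # per-channel l1-norm of the float weights
s = torch.tensor([0.1, 0.1])      # hypothetical quantized weight scale
g = torch.tensor([3.0, 3.0])      # learned norm parameter
T = torch.tensor(20.0)            # A2Q upper bound on the l1-norm (integer units)

g_eff = torch.clamp_max(g / s, T)        # cap g / s at T, as in forward above
pre_clipping_scale = d_w / g_eff         # capped channels get a larger scale
print(pre_clipping_scale)                # tensor([0.2000, 0.6000])
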
15 changes: 15 additions & 0 deletions src/brevitas/function/ops.py
@@ -201,3 +201,18 @@ def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_b
device=mantissa_bit_width.device)))
max_val = max_mantissa * (2 ** max_exponent)
return max_val


def get_upper_bound_on_l1_norm(
accumulator_bit_width: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Calculate the upper bound on the l1-norm of the weights using the derivations from
`Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance`
by I. Colbert, A. Pappalardo, and J. Petri-Koenig."""
assert input_bit_width is not None, "A2Q relies on input bit-width."
assert input_is_signed is not None, "A2Q relies on input sign."
assert accumulator_bit_width is not None, "A2Q relies on accumulator bit-width."
input_is_signed = float(input_is_signed) # 1. if signed else 0.
max_accumulator_bit_width = accumulator_bit_width # P
max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. # 2^{P-1}-1
max_input_mag_inverse = pow(2., input_is_signed - input_bit_width)
return max_accumulator_mag * max_input_mag_inverse
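
As a hedged sanity check of the new helper (values chosen purely for illustration): with a 16-bit accumulator and unsigned 8-bit inputs, the bound evaluates to (2^15 - 1) / 2^8, roughly 128.

import torch
from brevitas.function import get_upper_bound_on_l1_norm

T = get_upper_bound_on_l1_norm(
    accumulator_bit_width=torch.tensor(16.),
    input_bit_width=torch.tensor(8.),
    input_is_signed=False)
print(T)  # tensor(127.9961), i.e. (2**15 - 1) / 2**8
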
133 changes: 125 additions & 8 deletions src/brevitas/graph/gpfq.py
@@ -8,6 +8,7 @@
import torch
import unfoldNd

from brevitas.function import get_upper_bound_on_l1_norm
from brevitas.graph.gpxq import GPxQ
from brevitas.graph.gpxq import gpxq_mode
from brevitas.graph.gpxq import StopFwdException
@@ -45,7 +46,9 @@ def __init__(
use_quant_activations: bool = True,
p: float = 1.0,
return_forward_output: bool = False,
act_order: bool = False) -> None:
act_order: bool = False,
use_gpfa2q: bool = False,
accumulator_bit_width: Optional[int] = None) -> None:
if not inplace:
model = deepcopy(model)
super().__init__(
@@ -61,6 +64,10 @@ def __init__(
self.model.forward = self.catch_stopfwd
self.p = p

# GPFA2Q params
self.use_gpfa2q = use_gpfa2q
self.accumulator_bit_width = accumulator_bit_width

def catch_stopfwd(self, *args, **kwargs):
# Collect quant input
try:
@@ -96,13 +103,23 @@ def catch_stopfwd(self, *args, **kwargs):

def initialize_module_optimizer(
self, layer, name, act_order, len_parallel_layers, create_weight_orig):
return GPFQ(
layer=layer,
name=name,
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p)
if not self.use_gpfa2q:
return GPFQ(
layer=layer,
name=name,
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p)
else:
return GPFA2Q(
layer=layer,
name=name,
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p,
accumulator_bit_width=self.accumulator_bit_width)
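
For orientation, a hedged usage sketch of the new options (not part of this commit): the calibration loop below follows the usual Brevitas GPxQ pattern (num_layers, update), and the toy layer, quantizers, and data are assumptions made for illustration.

import torch
import brevitas.nn as qnn
from brevitas.quant import Int8ActPerTensorFloat
from brevitas.graph.gpfq import gpfq_mode

# Toy layer and calibration data (assumptions); an input quantizer is attached so
# that GPFA2Q can read the input bit-width and sign from the stored quant_input.
model = qnn.QuantLinear(64, 32, bias=False, input_quant=Int8ActPerTensorFloat)
calib_data = [torch.randn(8, 64) for _ in range(4)]

# use_gpfa2q and accumulator_bit_width are the options added by this commit.
with gpfq_mode(model, use_gpfa2q=True, accumulator_bit_width=16) as gpfq:
    for _ in range(gpfq.num_layers):  # assumed GPxQ calibration loop
        for x in calib_data:
            gpfq.model(x)
        gpfq.update()
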


class GPFQ(GPxQ):
@@ -256,3 +273,103 @@ def single_layer_update(self):

del self.float_input
del self.quantized_input


class GPFA2Q(GPFQ):

def __init__(
self,
layer,
name,
act_order,
len_parallel_layers,
create_weight_orig,
accumulator_bit_width,
p) -> None:
GPFQ.__init__(
self,
layer=layer,
name=name,
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=p)
self.accumulator_bit_width = accumulator_bit_width
assert self.accumulator_bit_width is not None

def single_layer_update(self):
# raise an error if no quantized input metadata is available
if self.quant_input is None:
raise ValueError(
'Expected quant input to calculate Upper Bound on L1 norm, but received None')
weight = self.layer.weight.data
dev = weight.device
dtype = weight.dtype
if isinstance(self.layer, SUPPORTED_CONV_OP):
if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
weight = weight.transpose(1, 0) # This performs a view
weight = weight.flatten(1)
weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC]
U = torch.zeros(
weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
self.float_input = self.float_input.to(dev)
self.quantized_input = self.quantized_input.to(dev)

# get upper bound
input_bit_width = self.quant_input.bit_width
input_is_signed = self.quant_input.signed
T = get_upper_bound_on_l1_norm(
torch.tensor(self.accumulator_bit_width), input_bit_width, input_is_signed)
s = self.layer.quant_weight_scale()
s = s.view(self.groups, -1) # [Groups, OC/Groups]

l1_norm = torch.zeros(weight.shape[:-1], device=dev)

# We don't need full Hessian, we just need the diagonal
self.H_diag = self.quantized_input.transpose(2, 1).square().sum(
2) # summing over Batch dimension
permutation_list = []
for group_index in range(self.groups):
if self.act_order:
# Re-order the Hessian diagonal so that weights associated with
# higher-magnitude activations are quantized first
perm = torch.argsort(self.H_diag[group_index, :], descending=True)
else:
# No permutation; the permutation tensor is just an ordered index
perm = torch.tensor(range(weight.shape[-1]), device=dev)
permutation_list.append(perm)

for t in range(weight.shape[-1]):
for group_index in range(self.groups):
U[group_index] += torch.matmul(
weight[group_index, :, permutation_list[group_index][t]].unsqueeze(1),
self.float_input[group_index, :, permutation_list[group_index][t]].unsqueeze(
0)) #[OC/Groups, 1] * [1, INSHAPE[1]]
norm = torch.linalg.norm(
self.quantized_input[group_index, :, permutation_list[group_index][t]], 2) ** 2
if norm > 0:
q_arg = U[group_index].matmul(
self.quantized_input[group_index, :,
permutation_list[group_index][t]]) / norm
else:
q_arg = torch.zeros_like(U[group_index, :, 0])

weight[group_index, :, permutation_list[group_index][t]] = q_arg
q = self.get_quant_weights(t, 0, permutation_list)

for group_index in range(self.groups):
candidate_l1 = l1_norm[group_index] + torch.abs(q[group_index])
candidate_l1_mask = candidate_l1 > T * s[group_index]
if torch.any(candidate_l1_mask):
# zero out all values that exceed T * s
weight[group_index, :, permutation_list[group_index][t]][candidate_l1_mask] = 0
q[group_index][candidate_l1_mask] = 0
else:
l1_norm[group_index] = candidate_l1
U[group_index] -= torch.matmul(
q[group_index].unsqueeze(1),
self.quantized_input[group_index, :,
permutation_list[group_index][t]].unsqueeze(0))

del self.float_input
del self.quantized_input
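
The per-channel l1-norm projection above is the core of GPFA2Q. Below is a hedged, elementwise toy version of that check (the commit applies it per group and guards the zeroing with torch.any): once a channel's running l1-norm would exceed T * s, the freshly quantized value for that channel is rejected and the norm is left unchanged.

import torch

T = torch.tensor(10.0)               # upper bound in integer units (made up)
s = torch.tensor([0.5, 0.5])         # per-channel weight scale
l1_norm = torch.tensor([4.0, 4.9])   # running l1-norm per output channel
q = torch.tensor([0.8, 0.3])         # newly quantized weights for column t

candidate_l1 = l1_norm + torch.abs(q)
mask = candidate_l1 > T * s                           # channels exceeding the budget
q = torch.where(mask, torch.zeros_like(q), q)         # reject the update for them
l1_norm = torch.where(mask, l1_norm, candidate_l1)    # keep the old norm where rejected
print(q, l1_norm)  # tensor([0.8000, 0.0000]) tensor([4.8000, 4.9000])
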
14 changes: 13 additions & 1 deletion src/brevitas/graph/gpxq.py
@@ -11,6 +11,8 @@
from typing import List, Optional, Set
import warnings

import torch

from brevitas.graph.calibrate import DisableEnableQuantization
import brevitas.nn as qnn
from brevitas.quant_tensor import QuantTensor
@@ -175,13 +177,23 @@ def process_input(self, inp):
if self.layer.weight_quant_requires_quant_input:
# Can minimize memory allocation by not storing actual values
self.quant_input = QuantTensor(
value=None,
value=torch.empty(
1, dtype=self.layer.weight.dtype, device=self.layer.weight.device),
scale=inp.scale,
zero_point=inp.zero_point,
bit_width=inp.bit_width,
signed=inp.signed,
training=inp.training)
inp = inp.value
elif self.layer.is_input_quant_enabled:
self.quant_input = QuantTensor(
value=torch.empty(
1, dtype=self.layer.weight.dtype, device=self.layer.weight.device),
scale=self.layer.quant_input_scale(),
zero_point=self.layer.quant_input_zero_point(),
bit_width=self.layer.quant_input_bit_width(),
signed=self.layer.is_quant_input_signed,
training=self.layer.training)

# If input is unbatched, add batch_size = 1
if len(inp.shape) == 1:
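
For completeness, a hedged sketch of the metadata-only QuantTensor that process_input now stores: GPFA2Q only reads the bit-width and sign from it, so a one-element placeholder stands in for the full activation tensor. All numeric values below are invented.

import torch
from brevitas.quant_tensor import QuantTensor

quant_input = QuantTensor(
    value=torch.empty(1),          # placeholder, the real activation values are not kept
    scale=torch.tensor(0.01),
    zero_point=torch.tensor(0.),
    bit_width=torch.tensor(8.),
    signed=False,
    training=False)
print(quant_input.bit_width, quant_input.signed)  # the fields GPFA2Q consumes
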
