Feat (axe): adding support for per-group quantization
i-colbert committed Oct 22, 2024
1 parent 6b7dd28 commit c2f0ba5
Showing 1 changed file with 86 additions and 56 deletions: src/brevitas_examples/common/axe.py
@@ -107,24 +107,34 @@ def single_layer_update(self, percdamp=0.01):
raise NotImplementedError("Signed inputs not yet supported.")

n_tiles = math.ceil(weight.shape[-1] / self.max_accumulator_tile_size)

- s = self.layer.weight_quant.scale()
+ scales: Tensor = self.layer.weight_quant.scale()
+ if isinstance(self.layer, SUPPORTED_CONV_OP):
+     if isinstance(self.layer, SUPPORTED_TCONV_OP):
+         scales = scales.transpose(1, 0)  # This performs a view
+     scales = scales.flatten(1)
P = torch.tensor(self.max_accumulator_bit_width)
N = self.quant_metadata.bit_width
# NOTE: using sign-magnitude here, which is sufficient to support both
# sign-magnitude and 2s complement accumulators
- A = (pow(2, P - 1) - 1) / float(pow(2, N) - 1)
- B = -A
- Z = (pow(2, P) - 2) / float(pow(2, N) - 1)
+ self.upper_lim = (pow(2, P - 1) - 1) / float(pow(2, N) - 1)  # A
+ self.lower_lim = -self.upper_lim  # B
+ Z = (pow(2, P) - 2) / float(pow(2, N) - 1)  # l1-norm lim for zero-centered weight vector
# translating into the quantized range; need to pad to get these thresholds
- wT = pad_tensor_with_zeros(weight / s, self.max_accumulator_tile_size).view(
+ wT = pad_tensor_with_zeros(weight / scales, self.max_accumulator_tile_size).view(
-1, self.max_accumulator_tile_size) # [OC * Tiles, IC / Tiles]
- T = calc_average_nonzero_mag(
-     wT - wT.mean(axis=1, keepdim=True), Z)  # [OC * Tiles, IC / Tiles]
- T = T.view(self.groups, n_tiles, -1)  # [Groups, Tiles, OC/Groups]
- s = s.view(self.groups, -1)  # [Groups, OC/Groups]
- T *= s  # translating centers back to the float range

+ thresholds = calc_average_nonzero_mag(wT - wT.mean(axis=1, keepdim=True), Z)  # [Groups * OC * Tiles]
+ thresholds = thresholds.view(self.groups, n_tiles, -1)  # [Groups, Tiles, OC/Groups]
+ del wT
+ # supporting groupwise quantization where each tile has its own scaling factor
+ if self.layer.weight_quant.is_groupwise:
+     scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view(-1, self.max_accumulator_tile_size)  # [Groups * OC * Tiles, IC / Tiles]
+     scales = scales[:, 0]  # [Groups * OC * Tiles]
+     scales = scales.view(self.groups, -1, n_tiles).transpose(1, 2)  # [Groups, Tiles, OC/Groups]
+ # else each tile has the same scaling factor (per-tensor or per-channel)
+ else:
+     scales = scales.view(self.groups, 1, -1)  # [Groups, 1, OC/Groups]
+     scales = scales.repeat(1, n_tiles, 1)  # [Groups, Tiles, OC/Groups]
+ thresholds *= scales  # translating centers back to the float range
weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC]

# List with permutation tensors for the Hessian and weight matrix.
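One way to read the A/B/Z limits defined in this hunk, assuming the sign-magnitude accumulator noted in the code comment: with unsigned N-bit inputs (maximum value 2^N - 1), a dot product stays inside a P-bit accumulator whenever the positive and the negative parts of the quantized weight vector each respect the l1-norm budget

    \|q^{+}\|_1,\ \|q^{-}\|_1 \le A = \frac{2^{P-1}-1}{2^{N}-1}, \qquad B = -A, \qquad Z = \frac{2^{P}-2}{2^{N}-1} = 2A

so a zero-centered weight vector, whose positive and negative parts carry equal mass, gets the doubled budget Z that the thresholds above are calibrated against.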
@@ -173,16 +183,15 @@ def single_layer_update(self, percdamp=0.01):
del self.H, self.B

# initialize cumulative l1-norm
- a = torch.zeros_like(T, device=dev)  # pos
- b = torch.zeros_like(T, device=dev)  # neg
+ a = torch.zeros_like(thresholds, device=dev)  # positive limits
+ b = torch.zeros_like(thresholds, device=dev)  # negative limits

for i1 in range(0, self.columns, self.blocksize):
i2 = min(i1 + self.blocksize, self.columns)
count = i2 - i1
error_block = torch.zeros_like(
weight[:, :, permutation_list[-1][i1:i2]],
-     dtype=torch.float32,
- )  # [groups, OC/groups, i2-i1]
+     dtype=torch.float32)  # [groups, OC/groups, i2-i1]

h_inv_block = h_inv[:, i1:i2, i1:i2]
for i in range(count):
@@ -192,14 +201,14 @@ def single_layer_update(self, percdamp=0.01):
bx = perm[i1:i2][i] // self.max_accumulator_tile_size # block index
# calculate the q_max and q_min for the right group and right block
# TODO: currently assuming round-to-zero; need to handle other rounding functions
- q_max = s[group_index, :] * torch.clamp_min(
-     A - a[group_index, bx, :] - 0.5, 0.0)  # [OC/groups]
- q_min = s[group_index, :] * torch.clamp_max(
-     B - b[group_index, bx, :] + 0.5, 0.0)  # [OC/groups]
+ q_max = scales[group_index, bx, :] * torch.clamp_min(
+     self.upper_lim - a[group_index, bx, :] - 0.5, 0.0)  # [OC/groups]
+ q_min = scales[group_index, bx, :] * torch.clamp_max(
+     self.lower_lim - b[group_index, bx, :] + 0.5, 0.0)  # [OC/groups]
q_arg = weight[group_index, :, perm[i1:i2][i]] # [OC/groups]
# soft thresholding then clamping
q_arg = q_arg.sign() * torch.relu(
-     q_arg.abs() - T[group_index, bx])  # [OC/groups]
+     q_arg.abs() - thresholds[group_index, bx])  # [OC/groups]
q_arg.clamp_(q_min, q_max) # clamping to bounds
weight[group_index, :, perm[i1:i2][i]] = q_arg
q_groups = self.get_quant_weights(i, i1, permutation_list) # [Groups, OC/groups]
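The q_max/q_min lines in this hunk convert the remaining l1-norm budget of a tile into clamping bounds for the next column. A minimal standalone sketch with made-up numbers (scalar limits standing in for self.upper_lim and self.lower_lim, two output channels, and the same 0.5 rounding margin):

    import torch

    # Hypothetical running totals: `a` is the positive l1-norm already consumed
    # in this tile, `b` the negative, both in the integer (scale-free) domain.
    upper_lim, lower_lim = 10.0, -10.0
    scale = torch.tensor([0.05, 0.10])
    a = torch.tensor([9.8, 4.0])
    b = torch.tensor([-2.0, -9.9])

    # Remaining budget, shrunk by 0.5 to guard against rounding, then translated
    # to the float domain by the scale; collapses to zero once exhausted.
    q_max = scale * torch.clamp_min(upper_lim - a - 0.5, 0.0)
    q_min = scale * torch.clamp_max(lower_lim - b + 0.5, 0.0)
    print(q_max)  # tensor([0.0000, 0.5500]) -> channel 0 has no positive budget left
    print(q_min)  # tensor([-0.3750, 0.0000]) -> channel 1 has no negative budget left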
@@ -219,12 +228,12 @@ def single_layer_update(self, percdamp=0.01):
for group_index in range(self.groups):
perm = permutation_list[group_index]
bx = perm[i1:i2][i] // self.max_accumulator_tile_size # block index
- q = q_groups[group_index] / s[group_index]  # [OC/groups]
+ q = q_groups[group_index] / scales[group_index, bx]  # [OC/groups]
# increment cumulative l1-norm
a[group_index, bx, q >= 0] += q[q >= 0]
b[group_index, bx, q <= 0] += q[q <= 0]
- assert (a <= A).all() and (a >= 0).all()
- assert (b >= B).all() and (b <= 0).all()
+ assert (a <= self.upper_lim).all() and (a >= 0).all()
+ assert (b >= self.lower_lim).all() and (b <= 0).all()

for group_index in range(self.groups):
perm = permutation_list[group_index]
@@ -234,6 +243,8 @@ def single_layer_update(self, percdamp=0.01):
if hasattr(self.layer, "offload_params"):
self.layer.offload_params(self.layer)

+ del thresholds, scales  # memory management
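The groupwise branch added above is the heart of the per-group support: every (output channel, input tile) pair carries its own scale, so thresholds and clamping bounds must be laid out as [Groups, Tiles, OC/Groups]. A minimal standalone sketch of that reshaping, assuming a single group, a tile size that divides IC evenly (so the zero-padding step is a no-op), and scales already expanded to the weight shape; the exact Brevitas scale layout may differ:

    import torch

    # Hypothetical shapes: 1 group, 8 output channels, 16 input channels,
    # accumulator tiles of 4 columns -> 4 tiles per output channel.
    groups, oc, ic, tile = 1, 8, 16, 4
    n_tiles = ic // tile

    # One scale per (output channel, tile), expanded to the weight shape [OC, IC].
    scales = torch.rand(oc, n_tiles).repeat_interleave(tile, dim=1)

    # Mirror of the groupwise branch: each row of the tiled view holds one tile
    # with a constant scale, so column 0 recovers the per-tile value.
    scales = scales.view(-1, tile)[:, 0]                       # [OC * Tiles]
    scales = scales.view(groups, -1, n_tiles).transpose(1, 2)  # [Groups, Tiles, OC/Groups]
    print(scales.shape)  # torch.Size([1, 4, 8])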


class A2GPFQ(GPFQv2):
"""
@@ -286,43 +297,62 @@ def single_layer_update(self, percdamp=0.01):
raise NotImplementedError("Signed inputs not yet supported.")

n_tiles = math.ceil(weight.shape[-1] / self.max_accumulator_tile_size)

- s = self.layer.weight_quant.scale()
+ scales: Tensor = self.layer.weight_quant.scale()
+ if isinstance(self.layer, SUPPORTED_CONV_OP):
+     if isinstance(self.layer, SUPPORTED_TCONV_OP):
+         scales = scales.transpose(1, 0)  # This performs a view
+     scales = scales.flatten(1)
P = torch.tensor(self.max_accumulator_bit_width)
N = self.quant_metadata.bit_width
# NOTE: using sign-magnitude here, which is sufficient to support both
# sign-magnitude and 2s complement accumulators
- A = (pow(2, P - 1) - 1) / float(pow(2, N) - 1)
- B = -A
- Z = (pow(2, P) - 2) / float(pow(2, N) - 1)
+ self.upper_lim = (pow(2, P - 1) - 1) / float(pow(2, N) - 1)  # A
+ self.lower_lim = -self.upper_lim  # B
+ Z = (pow(2, P) - 2) / float(pow(2, N) - 1)  # l1-norm lim for zero-centered weight vector
# translating into the quantized range; need to pad to get these thresholds
- wT = pad_tensor_with_zeros(weight / s, self.max_accumulator_tile_size).view(
+ wT = pad_tensor_with_zeros(weight / scales, self.max_accumulator_tile_size).view(
-1, self.max_accumulator_tile_size) # [OC * Tiles, IC / Tiles]
- T = calc_average_nonzero_mag(
-     wT - wT.mean(axis=1, keepdim=True), Z)  # [OC * Tiles, IC / Tiles]
- T = T.view(self.groups, n_tiles, -1)  # [Groups, Tiles, OC/Groups]
- s = s.view(self.groups, -1)  # [Groups, OC/Groups]
- T *= s  # translating centers back to the float range
+ thresholds = calc_average_nonzero_mag(wT - wT.mean(axis=1, keepdim=True), Z)  # [Groups * OC * Tiles]
+ thresholds = thresholds.view(self.groups, -1, n_tiles).transpose(1, 2)  # [Groups, Tiles, OC/Groups]
+ del wT
+ # supporting groupwise quantization where each tile has its own scaling factor
+ if self.layer.weight_quant.is_groupwise:
+     scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view(-1, self.max_accumulator_tile_size)  # [Groups * OC * Tiles, IC / Tiles]
+     scales = scales[:, 0]  # [Groups * OC * Tiles]
+     scales = scales.view(self.groups, -1, n_tiles).transpose(1, 2)  # [Groups, Tiles, OC/Groups]
+ # else each tile has the same scaling factor (per-tensor or per-channel)
+ else:
+     scales = scales.view(self.groups, 1, -1)  # [Groups, 1, OC/Groups]
+     scales = scales.repeat(1, n_tiles, 1)  # [Groups, Tiles, OC/Groups]
+ thresholds *= scales  # translating centers back to the float range

weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC]

# initialize cumulative l1-norm
- a = torch.zeros_like(T, device=dev)  # pos
- b = torch.zeros_like(T, device=dev)  # neg
+ a = torch.zeros_like(thresholds, device=dev)  # positive limit
+ b = torch.zeros_like(thresholds, device=dev)  # negative limit

- # stablize G with a dampening factor and then square root the matrix
- norms = torch.zeros((self.groups, self.columns), device=dev, dtype=torch.float32)
- self.H = self.H.to(dev)
- diag = torch.arange(self.columns, device='cpu')
- for i in range(self.groups):
-     damp = percdamp * self.H[i].diag().mean()
-     self.H[i, diag, diag] += damp
-     norms[i] = self.H[i].diag()  # set the norms post-dampening
-     eigvals, eigvecs = torch.linalg.eigh(self.H[i])
-     eigvals.clamp_min_(0.0).sqrt_()  # should be positive-definite
-     self.H[i] = eigvecs @ torch.diag(eigvals) @ eigvecs.t()
- del eigvecs, eigvals, diag
- self.quant_input = self.H  # NOTE: do this here for the `get_permutation_list` function
+ # Try/Except in case the square root of H cannot be computed
+ try:
+     norms = torch.zeros((self.groups, self.columns), device=dev, dtype=torch.float32)
+     self.H = self.H.to(dev)
+     diag = torch.arange(self.columns, device='cpu')
+     for i in range(self.groups):
+         # stabilize H with a dampening factor and then square root the matrix
+         damp = percdamp * self.H[i].diag().mean()
+         self.H[i, diag, diag] += damp
+         norms[i] = self.H[i].diag()  # set the norms post-dampening
+         eigvals, eigvecs = torch.linalg.eigh(self.H[i])
+         eigvals.clamp_min_(0.0).sqrt_()  # should be positive-definite
+         self.H[i] = eigvecs @ torch.diag(eigvals) @ eigvecs.t()
+     del eigvecs, eigvals, diag
+     self.quant_input = self.H  # NOTE: do this here for the `get_permutation_list` function
+ except LinAlgError:
+     warnings.warn(
+         f'Failed to compute the matrix square root of H for layer {self.name}. '
+         f'GPFQ will not be applied. '
+         f'Increasing the number of samples might fix this issue.')
+     return
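A minimal standalone sketch of the eigendecomposition route to the matrix square root wrapped in the try block above, on a toy dampened Gram matrix:

    import torch

    # Toy positive semi-definite matrix, dampened as above so the
    # eigenvalues stay strictly positive.
    X = torch.randn(64, 32)
    H = X @ X.t()
    H += 0.01 * H.diag().mean() * torch.eye(64)

    # H = V diag(e) V^T, hence sqrt(H) = V diag(sqrt(e)) V^T.
    eigvals, eigvecs = torch.linalg.eigh(H)
    sqrt_H = eigvecs @ torch.diag(eigvals.clamp_min(0.0).sqrt()) @ eigvecs.t()

    assert torch.allclose(sqrt_H @ sqrt_H, H, atol=1e-2)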

# Try/Except in case the inverse of H cannot be computed
try:
@@ -365,11 +395,11 @@ def single_layer_update(self, percdamp=0.01):
q_arg = torch.zeros_like(U[group_index, :, 0])
bx = i // self.max_accumulator_tile_size # block index
q_arg = q_arg.sign() * torch.relu(
-     q_arg.abs() - T[group_index, bx, :])  # soft thresholding
+     q_arg.abs() - thresholds[group_index, bx, :])  # soft thresholding

# TODO: assuming round to nearest; need to generally support other rounding
- q_max = s[group_index] * torch.clamp_min(A - a[group_index, bx, :] - 0.5, 0.0)
- q_min = s[group_index] * torch.clamp_max(B - b[group_index, bx, :] + 0.5, 0.0)
+ q_max = scales[group_index, bx] * torch.clamp_min(self.upper_lim - a[group_index, bx, :] - 0.5, 0.0)
+ q_min = scales[group_index, bx] * torch.clamp_max(self.lower_lim - b[group_index, bx, :] + 0.5, 0.0)
q_arg.clamp_(q_min, q_max)
weight[group_index, :, i] = q_arg.to(dtype)
q_groups: Tensor = self.get_quant_weights(t, 0, permutation_list)
@@ -379,11 +409,11 @@ def single_layer_update(self, percdamp=0.01):
q_groups[group_index].unsqueeze(1).to(torch.float32),
self.quant_input[group_index, :, i].unsqueeze(0))
bx = i // self.max_accumulator_tile_size # block index
- q = q_groups[group_index] / s[group_index]  # [OC/groups]
+ q = q_groups[group_index] / scales[group_index, bx]  # [OC/groups]
# increment cumulative l1-norm
a[group_index, bx, q >= 0] += q[q >= 0]
b[group_index, bx, q <= 0] += q[q <= 0]
- assert (a <= A).all() and (a >= 0).all()
- assert (b >= B).all() and (b <= 0).all()
+ assert (a <= self.upper_lim).all() and (a >= 0).all()
+ assert (b >= self.lower_lim).all() and (b <= 0).all()

del self.quant_input, self.float_input
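Both update loops share the same inner step: soft-threshold each column toward zero, then clamp it into the tile's remaining budget. A minimal standalone sketch with made-up values:

    import torch

    def soft_threshold_clamp(w, threshold, q_min, q_max):
        # Shrink every weight toward zero by `threshold` (zeroing anything
        # smaller in magnitude), then clamp into the remaining l1-norm budget.
        w = w.sign() * torch.relu(w.abs() - threshold)
        return w.clamp(q_min, q_max)

    w = torch.tensor([0.8, -0.3, 1.5, -2.0])
    out = soft_threshold_clamp(w, threshold=0.5, q_min=-1.0, q_max=1.0)
    print(out)  # ~ [0.3, -0.0, 1.0, -1.0]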
