Skip to content

Commit

Permalink
Pre-commit fixes
Browse files: browse the repository at this point in the history
  • Loading branch information
i-colbert committed Oct 22, 2024
1 parent b17dcc9 commit a429705
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 16 deletions.
14 changes: 11 additions & 3 deletions src/brevitas/graph/gpfq.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,11 @@ def single_layer_update(self):
self.float_input = self.float_input.to(dev)
self.quant_input = self.quant_input.to(dev)
U = torch.zeros(
weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=torch.float32)
weight.shape[0],
weight.shape[1],
self.float_input.shape[1],
device=dev,
dtype=torch.float32)
# We don't need full Hessian, we just need the diagonal
# Summing over batch dimension
H_diag = self.quant_input.transpose(2, 1).square().sum(2)
Expand All @@ -259,7 +263,8 @@ def single_layer_update(self):
for t in range(weight.shape[-1]):
for group_index in range(self.groups):
U[group_index] += torch.matmul(
weight[group_index, :, permutation_list[group_index][t]].unsqueeze(1).to(torch.float32),
weight[group_index, :,
permutation_list[group_index][t]].unsqueeze(1).to(torch.float32),
self.float_input[group_index, :, permutation_list[group_index][t]].unsqueeze(
0)) #[OC/Groups, 1] * [1, INSHAPE[1]]
norm = torch.linalg.norm(
Expand Down Expand Up @@ -446,7 +451,10 @@ def single_layer_update(self, percdamp: float = 0.01):
permutation_list = self._get_permutation_list(weight)

U = torch.zeros(
weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev,
weight.shape[0],
weight.shape[1],
self.float_input.shape[1],
device=dev,
dtype=torch.float32) # [Groups, OC/groups, Samples]

for t in range(weight.shape[-1]):
Expand Down
35 changes: 22 additions & 13 deletions src/brevitas_examples/common/axe.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,20 @@ def single_layer_update(self, percdamp=0.01):
# translating into the quantized range; need to pad to get these thresholds
wT = pad_tensor_with_zeros(weight / scales, self.max_accumulator_tile_size).view(
-1, self.max_accumulator_tile_size) # [OC * Tiles, IC / Tiles]
thresholds = calc_average_nonzero_mag(wT - wT.mean(axis=1, keepdim=True), Z) # [Groups * OC * Tiles]
thresholds = calc_average_nonzero_mag(
wT - wT.mean(axis=1, keepdim=True), Z) # [Groups * OC * Tiles]
thresholds = thresholds.view(self.groups, n_tiles, -1) # [Groups, Tiles, OC/Groups]
del wT
# supporting groupwise quantization where each tile has its own scaling factor
if self.layer.weight_quant.is_groupwise:
scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view(-1, self.max_accumulator_tile_size) # [Groups, OC * Tiles, IC / Tiles]
scales = scales[:,0] # [Groups * OC * Tiles, 1]
scales = scales.view(self.groups, -1, n_tiles).transpose(1,2) # [Groups, Tiles, OC/Groups]
scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view(
-1, self.max_accumulator_tile_size) # [Groups, OC * Tiles, IC / Tiles]
scales = scales[:, 0] # [Groups * OC * Tiles, 1]
scales = scales.view(self.groups, -1,
n_tiles).transpose(1, 2) # [Groups, Tiles, OC/Groups]
# else each tile has the same scaling factor (per-tensor or per-channel)
else:
scales = scales.view(self.groups, 1 , -1) # [Groups, 1, OC/Groups]
scales = scales.view(self.groups, 1, -1) # [Groups, 1, OC/Groups]
scales = scales.repeat(1, n_tiles, 1) # [Groups, Tiles, OC/Groups]
thresholds *= scales # translating centers back to the float range
weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC]
Expand Down Expand Up @@ -312,17 +315,21 @@ def single_layer_update(self, percdamp=0.01):
# translating into the quantized range; need to pad to get these thresholds
wT = pad_tensor_with_zeros(weight / scales, self.max_accumulator_tile_size).view(
-1, self.max_accumulator_tile_size) # [OC * Tiles, IC / Tiles]
thresholds = calc_average_nonzero_mag(wT - wT.mean(axis=1, keepdim=True), Z) # [Groups * OC * Tiles]
thresholds = thresholds.view(self.groups, -1, n_tiles).transpose(1,2) # [Groups, Tiles, OC/Groups]
thresholds = calc_average_nonzero_mag(
wT - wT.mean(axis=1, keepdim=True), Z) # [Groups * OC * Tiles]
thresholds = thresholds.view(self.groups, -1,
n_tiles).transpose(1, 2) # [Groups, Tiles, OC/Groups]
del wT
# supporting groupwise quantization where each tile has its own scaling factor
if self.layer.weight_quant.is_groupwise:
scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view(-1, self.max_accumulator_tile_size) # [Groups, OC * Tiles, IC / Tiles]
scales = scales[:,0] # [Groups * OC * Tiles, 1]
scales = scales.view(self.groups, -1, n_tiles).transpose(1,2) # [Groups, Tiles, OC/Groups]
scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view(
-1, self.max_accumulator_tile_size) # [Groups, OC * Tiles, IC / Tiles]
scales = scales[:, 0] # [Groups * OC * Tiles, 1]
scales = scales.view(self.groups, -1,
n_tiles).transpose(1, 2) # [Groups, Tiles, OC/Groups]
# else each tile has the same scaling factor (per-tensor or per-channel)
else:
scales = scales.view(self.groups, 1 , -1) # [Groups, 1, OC/Groups]
scales = scales.view(self.groups, 1, -1) # [Groups, 1, OC/Groups]
scales = scales.repeat(1, n_tiles, 1) # [Groups, Tiles, OC/Groups]
thresholds *= scales # translating centers back to the float range

Expand Down Expand Up @@ -398,8 +405,10 @@ def single_layer_update(self, percdamp=0.01):
q_arg.abs() - thresholds[group_index, bx, :]) # soft thresholding

# TODO: assuming round to nearest; need to generally support other rounding
q_max = scales[group_index, bx] * torch.clamp_min(self.upper_lim - a[group_index, bx, :] - 0.5, 0.0)
q_min = scales[group_index, bx] * torch.clamp_max(self.lower_lim - b[group_index, bx, :] + 0.5, 0.0)
q_max = scales[group_index, bx] * torch.clamp_min(
self.upper_lim - a[group_index, bx, :] - 0.5, 0.0)
q_min = scales[group_index, bx] * torch.clamp_max(
self.lower_lim - b[group_index, bx, :] + 0.5, 0.0)
q_arg.clamp_(q_min, q_max)
weight[group_index, :, i] = q_arg.to(dtype)
q_groups: Tensor = self.get_quant_weights(t, 0, permutation_list)
Expand Down

0 comments on commit a429705

Please sign in to comment.