diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
index c1826bb6d..9a32eb6b5 100644
--- a/src/brevitas/graph/gpfq.py
+++ b/src/brevitas/graph/gpfq.py
@@ -241,7 +241,11 @@ def single_layer_update(self):
         self.float_input = self.float_input.to(dev)
         self.quant_input = self.quant_input.to(dev)
         U = torch.zeros(
-            weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=torch.float32)
+            weight.shape[0],
+            weight.shape[1],
+            self.float_input.shape[1],
+            device=dev,
+            dtype=torch.float32)
         # We don't need full Hessian, we just need the diagonal
         # Summing over batch dimension
         H_diag = self.quant_input.transpose(2, 1).square().sum(2)
@@ -259,7 +263,8 @@ def single_layer_update(self):
         for t in range(weight.shape[-1]):
             for group_index in range(self.groups):
                 U[group_index] += torch.matmul(
-                    weight[group_index, :, permutation_list[group_index][t]].unsqueeze(1).to(torch.float32),
+                    weight[group_index, :,
+                           permutation_list[group_index][t]].unsqueeze(1).to(torch.float32),
                     self.float_input[group_index, :, permutation_list[group_index][t]].unsqueeze(
                         0))  #[OC/Groups, 1] * [1, INSHAPE[1]]
                 norm = torch.linalg.norm(
@@ -446,7 +451,10 @@ def single_layer_update(self, percdamp: float = 0.01):
         permutation_list = self._get_permutation_list(weight)
 
         U = torch.zeros(
-            weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev,
+            weight.shape[0],
+            weight.shape[1],
+            self.float_input.shape[1],
+            device=dev,
             dtype=torch.float32)  # [Groups, OC/groups, Samples]
 
         for t in range(weight.shape[-1]):
diff --git a/src/brevitas_examples/common/axe.py b/src/brevitas_examples/common/axe.py
index 05bc99c22..ff7569e53 100644
--- a/src/brevitas_examples/common/axe.py
+++ b/src/brevitas_examples/common/axe.py
@@ -122,17 +122,20 @@ def single_layer_update(self, percdamp=0.01):
         # translating into the quantized range; need to pad to get these thresholds
         wT = pad_tensor_with_zeros(weight / scales, self.max_accumulator_tile_size).view(
             -1, self.max_accumulator_tile_size)  # [OC * Tiles, IC / Tiles]
-        thresholds = calc_average_nonzero_mag(wT - wT.mean(axis=1, keepdim=True), Z)  # [Groups * OC * Tiles]
+        thresholds = calc_average_nonzero_mag(
+            wT - wT.mean(axis=1, keepdim=True), Z)  # [Groups * OC * Tiles]
         thresholds = thresholds.view(self.groups, n_tiles, -1)  # [Groups, Tiles, OC/Groups]
         del wT
         # supporting groupwise quantization where each tile has its own scaling factor
         if self.layer.weight_quant.is_groupwise:
-            scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view(-1, self.max_accumulator_tile_size)  # [Groups, OC * Tiles, IC / Tiles]
-            scales = scales[:,0]  # [Groups * OC * Tiles, 1]
-            scales = scales.view(self.groups, -1, n_tiles).transpose(1,2)  # [Groups, Tiles, OC/Groups]
+            scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view(
+                -1, self.max_accumulator_tile_size)  # [Groups, OC * Tiles, IC / Tiles]
+            scales = scales[:, 0]  # [Groups * OC * Tiles, 1]
+            scales = scales.view(self.groups, -1,
+                                 n_tiles).transpose(1, 2)  # [Groups, Tiles, OC/Groups]
         # else each tile has the same scaling factor (per-tensor or per-channel)
         else:
-            scales = scales.view(self.groups, 1 , -1)  # [Groups, 1, OC/Groups]
+            scales = scales.view(self.groups, 1, -1)  # [Groups, 1, OC/Groups]
             scales = scales.repeat(1, n_tiles, 1)  # [Groups, Tiles, OC/Groups]
         thresholds *= scales  # translating centers back to the float range
         weight = weight.view(self.groups, -1, weight.shape[-1])  # [Groups, OC/Groups, IC]
@@ -312,17 +315,21 @@ def single_layer_update(self, percdamp=0.01):
         # translating into the quantized range; need to pad to get these thresholds
         wT = pad_tensor_with_zeros(weight / scales, self.max_accumulator_tile_size).view(
             -1, self.max_accumulator_tile_size)  # [OC * Tiles, IC / Tiles]
-        thresholds = calc_average_nonzero_mag(wT - wT.mean(axis=1, keepdim=True), Z)  # [Groups * OC * Tiles]
-        thresholds = thresholds.view(self.groups, -1, n_tiles).transpose(1,2)  # [Groups, Tiles, OC/Groups]
+        thresholds = calc_average_nonzero_mag(
+            wT - wT.mean(axis=1, keepdim=True), Z)  # [Groups * OC * Tiles]
+        thresholds = thresholds.view(self.groups, -1,
+                                     n_tiles).transpose(1, 2)  # [Groups, Tiles, OC/Groups]
         del wT
         # supporting groupwise quantization where each tile has its own scaling factor
         if self.layer.weight_quant.is_groupwise:
-            scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view(-1, self.max_accumulator_tile_size)  # [Groups, OC * Tiles, IC / Tiles]
-            scales = scales[:,0]  # [Groups * OC * Tiles, 1]
-            scales = scales.view(self.groups, -1, n_tiles).transpose(1,2)  # [Groups, Tiles, OC/Groups]
+            scales = pad_tensor_with_zeros(scales, self.max_accumulator_tile_size).view(
+                -1, self.max_accumulator_tile_size)  # [Groups, OC * Tiles, IC / Tiles]
+            scales = scales[:, 0]  # [Groups * OC * Tiles, 1]
+            scales = scales.view(self.groups, -1,
+                                 n_tiles).transpose(1, 2)  # [Groups, Tiles, OC/Groups]
         # else each tile has the same scaling factor (per-tensor or per-channel)
         else:
-            scales = scales.view(self.groups, 1 , -1)  # [Groups, 1, OC/Groups]
+            scales = scales.view(self.groups, 1, -1)  # [Groups, 1, OC/Groups]
             scales = scales.repeat(1, n_tiles, 1)  # [Groups, Tiles, OC/Groups]
         thresholds *= scales  # translating centers back to the float range
 
@@ -398,8 +405,10 @@ def single_layer_update(self, percdamp=0.01):
                     q_arg.abs() - thresholds[group_index, bx, :])  # soft thresholding
                 # TODO: assuming round to nearest; need to generally support other rounding
-                q_max = scales[group_index, bx] * torch.clamp_min(self.upper_lim - a[group_index, bx, :] - 0.5, 0.0)
-                q_min = scales[group_index, bx] * torch.clamp_max(self.lower_lim - b[group_index, bx, :] + 0.5, 0.0)
+                q_max = scales[group_index, bx] * torch.clamp_min(
+                    self.upper_lim - a[group_index, bx, :] - 0.5, 0.0)
+                q_min = scales[group_index, bx] * torch.clamp_max(
+                    self.lower_lim - b[group_index, bx, :] + 0.5, 0.0)
                 q_arg.clamp_(q_min, q_max)
                 weight[group_index, :, i] = q_arg.to(dtype)
 
             q_groups: Tensor = self.get_quant_weights(t, 0, permutation_list)
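Note for reviewers: every hunk above appears to be a formatting-only re-wrap (yapf-style line splitting); the removed and added lines differ only in whitespace. For anyone skimming the gpfq.py hunks, the re-wrapped torch.zeros call allocates the U error accumulator and sits next to the Hessian-diagonal shortcut the comments describe: for the layer-wise least-squares objective used by GPxQ-style solvers, the Hessian is proportional to X X^T, so only the per-input-channel sum of squares is needed. A minimal sketch of that shortcut, with invented shapes (groups, samples, ic are illustrative, not from the diff):

import torch

groups, samples, ic = 2, 16, 8  # illustrative sizes only
quant_input = torch.randn(groups, samples, ic)  # [Groups, Samples, IC]

# same expression as the diff's H_diag line: square, then sum over samples
H_diag = quant_input.transpose(2, 1).square().sum(2)  # [Groups, IC]

# equivalent einsum form, included only as a shape/value check
assert torch.allclose(H_diag, torch.einsum('gsi,gsi->gi', quant_input, quant_input))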
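The axe.py hunks re-wrap the accumulator-aware soft-thresholding and clamping step. A standalone sketch of what those lines compute follows; the function name is hypothetical, the sign(x) * relu(|x| - t) form of the threshold is inferred from the visible continuation line, and a/b are read as the running positive/negative quantized sums that consume the accumulator budget:

import torch

def soft_threshold_and_clamp(
        q_arg: torch.Tensor,      # error-compensated weight column (float range)
        threshold: torch.Tensor,  # per-tile center from calc_average_nonzero_mag
        scale: torch.Tensor,      # quantization scale for this tile
        a: torch.Tensor,          # running sum of positive quantized values
        b: torch.Tensor,          # running sum of negative quantized values
        upper_lim: float,         # accumulator budget, positive side
        lower_lim: float) -> torch.Tensor:
    # soft thresholding: shrink magnitudes by the threshold, zeroing anything
    # smaller than it (encourages sparsity within each accumulator tile)
    q_arg = torch.sign(q_arg) * torch.relu(q_arg.abs() - threshold)
    # assuming round-to-nearest (the diff's TODO notes other modes are open):
    # the 0.5 margin keeps the value within the remaining budget even after
    # rounding; multiplying by scale maps the bound back to the float range
    q_max = scale * torch.clamp_min(upper_lim - a - 0.5, 0.0)
    q_min = scale * torch.clamp_max(lower_lim - b + 0.5, 0.0)
    return q_arg.clamp(q_min, q_max)

# toy usage with made-up numbers
q = soft_threshold_and_clamp(
    q_arg=torch.tensor([0.9, -0.2, 0.5]),
    threshold=torch.tensor(0.25),
    scale=torch.tensor(0.1),
    a=torch.tensor([100.0, 100.0, 100.0]),
    b=torch.tensor([-30.0, -30.0, -30.0]),
    upper_lim=127.0,
    lower_lim=-128.0)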