Skip to content

Commit

Permalink
fixes act_order for group-wise convolutions
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianandresgrob committed Nov 15, 2023
1 parent a4ea065 commit c8cc838
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions src/brevitas/graph/gpfq.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,22 +234,25 @@ def single_layer_update(self):
for t in range(weight.shape[-1]):
for group_index in range(self.groups):
U[group_index] += torch.matmul(
weight[group_index, :, perm[t]].unsqueeze(1),
self.float_input[group_index, :,
perm[t]].unsqueeze(0)) #[OC/Groups, 1] * [1, INSHAPE[1]]
norm = torch.linalg.norm(self.quantized_input[group_index, :, perm[t]], 2) ** 2
weight[group_index, :, permutation_list[group_index][t]].unsqueeze(1),
self.float_input[group_index, :, permutation_list[group_index][t]].unsqueeze(
0)) #[OC/Groups, 1] * [1, INSHAPE[1]]
norm = torch.linalg.norm(
self.quantized_input[group_index, :, permutation_list[group_index][t]], 2) ** 2
if norm > 0:
q_arg = U[group_index].matmul(
self.quantized_input[group_index, :, perm[t]]) / norm
self.quantized_input[group_index, :,
permutation_list[group_index][t]]) / norm
else:
q_arg = torch.zeros_like(U[group_index, :, 0])

weight[group_index, :, perm[t]] = q_arg
weight[group_index, :, permutation_list[group_index][t]] = q_arg
q = self.get_quant_weights(t, 0, permutation_list)
for group_index in range(self.groups):
U[group_index] -= torch.matmul(
q[group_index].unsqueeze(1),
self.quantized_input[group_index, :, perm[t]].unsqueeze(0))
self.quantized_input[group_index, :,
permutation_list[group_index][t]].unsqueeze(0))

del self.float_input
del self.quantized_input

0 comments on commit c8cc838

Please sign in to comment.