Skip to content

Commit

Permalink
Fix (gpxq): correct variable name (#944)
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Apr 29, 2024
1 parent cae4004 commit 66f28b2
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/brevitas/graph/gpxq.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,15 +266,15 @@ def get_quant_weights(self, i, i1, permutation_list):
subtensor_slice_list = [None, (index, index + 1)]
q = self.layer.quant_weight(
subtensor_slice_list=subtensor_slice_list,
quant_input=self.quant_input).value.unsqueeze(0) # [1, OC, 1]
quant_input=self.quant_metadata).value.unsqueeze(0) # [1, OC, 1]
elif isinstance(self.layer, SUPPORTED_CONV_OP):
# For depthwise and ConvTranspose we fall back to quantizing the entire martix.
# For all other cases, we create a mask that represent the slicing we will perform on the weight matrix
# and we quantize only the selected dimensions.
if self.groups > 1 or (self.groups == 1 and isinstance(
self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d))):

quant_weight = self.layer.quant_weight(quant_input=self.quant_input)
quant_weight = self.layer.quant_weight(quant_input=self.quant_metadata)
quant_weight = quant_weight.value

if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
Expand All @@ -299,7 +299,7 @@ def get_quant_weights(self, i, i1, permutation_list):
index_2d_to_nd.insert(0, None)
q = self.layer.quant_weight(
subtensor_slice_list=index_2d_to_nd,
quant_input=self.quant_input).value.flatten(1) # [OC, 1]
quant_input=self.quant_metadata).value.flatten(1) # [OC, 1]
q = q.unsqueeze(0) # [1, OC, 1]
# We need to remove the last dim
q = q.squeeze(2) # [groups, OC/groups] or [1, OC]
Expand Down

0 comments on commit 66f28b2

Please sign in to comment.