diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index c38a15712..b85ac1188 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -266,7 +266,7 @@ def get_quant_weights(self, i, i1, permutation_list): subtensor_slice_list = [None, (index, index + 1)] q = self.layer.quant_weight( subtensor_slice_list=subtensor_slice_list, - quant_input=self.quant_input).value.unsqueeze(0) # [1, OC, 1] + quant_input=self.quant_metadata).value.unsqueeze(0) # [1, OC, 1] elif isinstance(self.layer, SUPPORTED_CONV_OP): # For depthwise and ConvTranspose we fall back to quantizing the entire martix. # For all other cases, we create a mask that represent the slicing we will perform on the weight matrix @@ -274,7 +274,7 @@ def get_quant_weights(self, i, i1, permutation_list): if self.groups > 1 or (self.groups == 1 and isinstance( self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d))): - quant_weight = self.layer.quant_weight(quant_input=self.quant_input) + quant_weight = self.layer.quant_weight(quant_input=self.quant_metadata) quant_weight = quant_weight.value if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): @@ -299,7 +299,7 @@ def get_quant_weights(self, i, i1, permutation_list): index_2d_to_nd.insert(0, None) q = self.layer.quant_weight( subtensor_slice_list=index_2d_to_nd, - quant_input=self.quant_input).value.flatten(1) # [OC, 1] + quant_input=self.quant_metadata).value.flatten(1) # [OC, 1] q = q.unsqueeze(0) # [1, OC, 1] # We need to remove the last dim q = q.squeeze(2) # [groups, OC/groups] or [1, OC]