Skip to content

Commit

Permalink
Feat (GPFA2Q): unify quant_input initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianandresgrob committed Dec 7, 2023
1 parent e5c6daa commit a3dad71
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/brevitas/graph/gpfq.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,10 @@ def __init__(
self.accumulator_bit_width = accumulator_bit_width

def single_layer_update(self):
+ # raise error in case no quant-input is here
+ if self.quant_input is None:
+ raise ValueError(
+ 'Expected quant input to calculate Upper Bound on L1 norm, but received None')
weight = self.layer.weight.data
dev = weight.device
dtype = weight.dtype
Expand All @@ -311,8 +315,8 @@ def single_layer_update(self):
self.quantized_input = self.quantized_input.to(dev)

# get upper bound
- input_bit_width = self.layer.quant_input_bit_width()
- input_is_signed = self.layer.is_quant_input_signed
+ input_bit_width = self.quant_input.bit_width
+ input_is_signed = self.quant_input.signed
T = get_upper_bound_on_l1_norm(self.accumulator_bit_width, input_bit_width, input_is_signed)
s = self.layer.quant_weight_scale()
s = s.view(self.groups, -1) # [Groups, OC/Groups]
Expand Down
8 changes: 8 additions & 0 deletions src/brevitas/graph/gpxq.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,14 @@ def process_input(self, inp):
signed=inp.signed,
training=inp.training)
inp = inp.value
+ elif self.layer.is_input_quant_enabled:
+ self.quant_input = QuantTensor(
+ value=None,
+ scale=self.layer.quant_input_scale(),
+ zero_point=self.layer.quant_input_zero_point(),
+ bit_width=self.layer.quant_input_bit_width(),
+ signed=self.layer.is_quant_input_signed,
+ training=self.layer.training)

# If input is unbatched, add batch_size = 1
if len(inp.shape) == 1:
Expand Down

0 comments on commit a3dad71

Please sign in to comment.