Skip to content

Commit

Permalink
QuantTensor checks
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 2, 2023
1 parent 96adc87 commit 04944b5
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/brevitas/nn/quant_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,21 +312,25 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
return out

quant_input = self.input_quant(inp)
# quant_input_value = getattr(quant_input, 'value', quant_input)
# quant_input_scale = getattr(quant_input, 'scale', None)
# quant_input_bitwidth = getattr(quant_input, 'bit_width', None)

quant_weight = self.quant_weight(quant_input)
# quant_weight_value = getattr(quant_weight, 'value', quant_weight)
# quant_weight_scale = getattr(quant_weight, 'scale', None)
# quant_weight_bitwidth = getattr(quant_weight, 'bit_width', None)

if (self.return_quant_tensor or
(self.is_bias_quant_enabled and
(self.bias_quant.requires_input_scale or self.bias_quant.requires_input_bit_width))):
if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor):
output_bit_width = self.max_acc_bit_width(
quant_input.bit_width, quant_weight.bit_width)
if hasattr(
quant_input,
'scale') and quant_input.scale is not None and quant_weight.scale is not None:

output_scale = self.quant_output_scale_impl(
inp, quant_input.scale, quant_weight.scale)
if hasattr(quant_input, 'signed') and quant_input.signed is not None:

quant_input_signed = quant_input.signed if isinstance(
quant_input, QuantTensor) else True
quant_weight_signed = quant_weight.signed if isinstance(
Expand Down

0 comments on commit 04944b5

Please sign in to comment.