diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 6ad06c4c2..7208aa8e3 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -310,12 +310,17 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe quant_input = self.input_quant(inp) quant_weight = self.quant_weight(quant_input) - if quant_input.bit_width is not None and quant_weight.bit_width is not None: - output_bit_width = self.max_acc_bit_width(quant_input.bit_width, quant_weight.bit_width) - if quant_input.scale is not None and quant_weight.scale is not None: - output_scale = self.quant_output_scale_impl(inp, quant_input.scale, quant_weight.scale) - if quant_input.signed is not None: - output_signed = inp.signed or quant_weight.signed + if (self.return_quant_tensor or + (self.is_bias_quant_enabled and + (self.bias_quant.requires_input_scale or self.bias_quant.requires_input_bit_width))): + if quant_input.bit_width is not None and quant_weight.bit_width is not None: + output_bit_width = self.max_acc_bit_width( + quant_input.bit_width, quant_weight.bit_width) + if quant_input.scale is not None and quant_weight.scale is not None: + output_scale = self.quant_output_scale_impl( + inp, quant_input.scale, quant_weight.scale) + if quant_input.signed is not None: + output_signed = inp.signed or quant_weight.signed if self.bias is not None: quant_bias = self.bias_quant(self.bias, output_scale, output_bit_width)