From 7e16d9c5a91459645aa38b1592e32b6b0fde40a1 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sun, 4 Feb 2024 00:55:16 +0000 Subject: [PATCH] Fix for output quant metadata --- src/brevitas/nn/quant_layer.py | 5 +++++ src/brevitas/nn/quant_rnn.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 9590f8e11..dcc9d6f58 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -296,6 +296,11 @@ def merge_bn_in(self, bn): merge_bn(self, bn, output_channel_dim=self.output_channel_dim) def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: + output_scale = None + output_bit_width = None + output_signed = None + output_zero_point = None + inp = self.unpack_input(inp) # shortcut execution through the export impl during export diff --git a/src/brevitas/nn/quant_rnn.py b/src/brevitas/nn/quant_rnn.py index 4b00dac20..63bd55e19 100644 --- a/src/brevitas/nn/quant_rnn.py +++ b/src/brevitas/nn/quant_rnn.py @@ -419,7 +419,7 @@ def forward(self, inp, state): quant_weight_ih, quant_weight_hh, quant_bias = self.gate_params_fwd( self.gate_params, quant_input) quant_input_value = _unpack_quant_tensor(quant_input) - if getattr(quant_bias, 'value', quant_bias) is None: + if quant_bias is None: quant_bias = torch.tensor(0., device=quant_input_value.device) else: quant_bias = _unpack_quant_tensor(quant_bias)