From f106c7b2e0f4c0e01f372aaa2ae7455f1a79d989 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 19 Feb 2024 14:29:36 +0000 Subject: [PATCH] cleanup --- src/brevitas/proxy/parameter_quant.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 97f495b2e..e690d81b0 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -189,13 +189,13 @@ def forward(self, x: Tensor, input_scale: Optional[Tensor] = None) -> Union[Ten impl = self.export_handler if self.export_mode else self.tensor_quant if self.requires_input_scale and input_scale is None: raise RuntimeError("Input scale required") + if self.requires_input_scale: input_scale = input_scale.view(-1) out, out_scale, out_zp, out_bit_width = impl(x, input_scale) - elif not self.requires_input_scale: - out, out_scale, out_zp, out_bit_width = impl(x) else: - raise RuntimeError("Internally defined bit-width required") + out, out_scale, out_zp, out_bit_width = impl(x) + return QuantTensor(out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) else: return x