diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 32ecaa81e..6cd2c03ed 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -227,7 +227,7 @@ def create_quant_tensor( qt_args: Union[Tensor, Tuple[Any]], x: Optional[IntQuantTensor] = None) -> IntQuantTensor: - if x is None: + if isinstance(qt_args, tuple): out = IntQuantTensor(*qt_args, self.is_signed, self.training) else: out = IntQuantTensor(