diff --git a/src/brevitas/export/onnx/qonnx/function.py b/src/brevitas/export/onnx/qonnx/function.py index 56701f6e8..5160572ef 100644 --- a/src/brevitas/export/onnx/qonnx/function.py +++ b/src/brevitas/export/onnx/qonnx/function.py @@ -71,6 +71,7 @@ def symbolic( exponent_bias, has_inf, has_nan, + saturating, has_subnormal, rounding_mode, max_val): @@ -81,11 +82,12 @@ def symbolic( exponent_bit_width, mantissa_bit_width, exponent_bias, + max_val, has_inf_i=int(has_inf), has_nan_i=int(has_nan), has_subnormal_i=int(has_subnormal), rounding_mode_s=rounding_mode, - max_val_f=max_val) + saturation_i=saturating) ret.setType(x.type()) return ret @@ -99,6 +101,7 @@ def forward( exponent_bias, has_inf, has_nan, + saturating, has_subnormal, rounding_mode, max_val): diff --git a/src/brevitas/export/onnx/qonnx/handler.py b/src/brevitas/export/onnx/qonnx/handler.py index 1f4d6780e..3b9a2d41c 100644 --- a/src/brevitas/export/onnx/qonnx/handler.py +++ b/src/brevitas/export/onnx/qonnx/handler.py @@ -40,6 +40,7 @@ def prepare_for_export(self, module): 'exponent_bias': module.exponent_bias(), 'has_inf': module.inf_values() is not None, 'has_nan': module.nan_values() is not None, + 'saturating': module.is_saturating(), 'has_subnormal': True, # Currently we only support subnormal 'rounding_mode': module.rounding_mode, 'max_float': module.quant_injector.max_available_float} @@ -54,9 +55,6 @@ def prepare_for_export(self, module): 'nan_values': module.nan_values(),} def symbolic_execution(self, x: Tensor): - xx = tuple(self.symbolic_kwargs.values()) - scale = self.symbolic_kwargs['scale'] - print(self.symbolic_kwargs.values()) x = BrevitasFloatQuantFn.apply(x, *self.symbolic_kwargs.values()) return x, *self.return_args.values() @@ -69,7 +67,7 @@ def __init__(self): self.quant_weights = None def validate(self, zero_point): - assert zero_point == 0, "Zero-point not supported for binary quant." + assert zero_point == 0, "Zero-point not supported for minifloat quant." def prepare_for_export(self, module: WeightQuantProxyFromInjector): if module.is_quant_enabled: @@ -82,6 +80,7 @@ def prepare_for_export(self, module: WeightQuantProxyFromInjector): 'exponent_bias': first_qweight.exponent_bias, 'has_inf': first_qweight.inf_values is not None, 'has_nan': first_qweight.nan_values is not None, + 'saturating': first_qweight.saturating, 'has_subnormal': True, # Currently we only support subnormal 'rounding_mode': module.rounding_mode, 'max_float': module.quant_injector.max_available_float,}