Skip to content

Commit

Permalink
Review
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 30, 2024
1 parent 00c45f4 commit 5cc8b95
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
5 changes: 4 additions & 1 deletion src/brevitas/export/onnx/qonnx/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def symbolic(
exponent_bias,
has_inf,
has_nan,
saturating,
has_subnormal,
rounding_mode,
max_val):
Expand All @@ -81,11 +82,12 @@ def symbolic(
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
max_val,
has_inf_i=int(has_inf),
has_nan_i=int(has_nan),
has_subnormal_i=int(has_subnormal),
rounding_mode_s=rounding_mode,
max_val_f=max_val)
saturation_i=saturating)
ret.setType(x.type())
return ret

Expand All @@ -99,6 +101,7 @@ def forward(
exponent_bias,
has_inf,
has_nan,
saturating,
has_subnormal,
rounding_mode,
max_val):
Expand Down
7 changes: 3 additions & 4 deletions src/brevitas/export/onnx/qonnx/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def prepare_for_export(self, module):
'exponent_bias': module.exponent_bias(),
'has_inf': module.inf_values() is not None,
'has_nan': module.nan_values() is not None,
'saturating': module.is_saturating(),
'has_subnormal': True, # Currently we only support subnormal
'rounding_mode': module.rounding_mode,
'max_float': module.quant_injector.max_available_float}
Expand All @@ -54,9 +55,6 @@ def prepare_for_export(self, module):
'nan_values': module.nan_values(),}

def symbolic_execution(self, x: Tensor):
xx = tuple(self.symbolic_kwargs.values())
scale = self.symbolic_kwargs['scale']
print(self.symbolic_kwargs.values())
x = BrevitasFloatQuantFn.apply(x, *self.symbolic_kwargs.values())
return x, *self.return_args.values()

Expand All @@ -69,7 +67,7 @@ def __init__(self):
self.quant_weights = None

def validate(self, zero_point):
assert zero_point == 0, "Zero-point not supported for binary quant."
assert zero_point == 0, "Zero-point not supported for minifloat quant."

def prepare_for_export(self, module: WeightQuantProxyFromInjector):
if module.is_quant_enabled:
Expand All @@ -82,6 +80,7 @@ def prepare_for_export(self, module: WeightQuantProxyFromInjector):
'exponent_bias': first_qweight.exponent_bias,
'has_inf': first_qweight.inf_values is not None,
'has_nan': first_qweight.nan_values is not None,
'saturating': first_qweight.saturating,
'has_subnormal': True, # Currently we only support subnormal
'rounding_mode': module.rounding_mode,
'max_float': module.quant_injector.max_available_float,}
Expand Down

0 comments on commit 5cc8b95

Please sign in to comment.