diff --git a/src/brevitas/export/onnx/standard/function.py b/src/brevitas/export/onnx/standard/function.py
index 1a1de5b15..0b17bb31e 100644
--- a/src/brevitas/export/onnx/standard/function.py
+++ b/src/brevitas/export/onnx/standard/function.py
@@ -6,6 +6,7 @@ from brevitas.export.onnx import onnx_export_opset
 
 AXIS_OPSET = 13
+DTYPE_OPSET = 19
 
 
 class DequantizeLinearFn(Function):
 
@@ -24,7 +25,13 @@ def symbolic(g, x, input_scale, input_zero_point, input_axis):
 
     @staticmethod
     def forward(ctx, int_x, input_scale, input_zero_point, input_axis):
-        return int_x.float()
+        opset_version = onnx_export_opset()
+        # Below DTYPE_OPSET the output of DequantizeLinear is always float;
+        # from DTYPE_OPSET onward it has the same dtype as input_scale.
+        if opset_version < DTYPE_OPSET:
+            return int_x.float()
+        else:
+            return int_x.to(input_scale.dtype)
 
 
 class IntClipFn(Function):