Feat (export/onnx): dtype support for DequantizeLinear
Giuseppe5 committed Dec 8, 2023
1 parent 422b632 commit 50ea56a
Showing 6 changed files with 93 additions and 18 deletions.
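
At its core, the change works around an ONNX constraint: QuantizeLinear and DequantizeLinear only accept float32 scales before opset 19, so exporting a float16 or bfloat16 model means casting scales (and inputs) up to float32 around the Q/DQ nodes and casting the results back afterwards. A minimal sketch of that pattern, using hypothetical dequantize_fn/cast_fn callables rather than the actual Brevitas handler API:

import torch

def dequantize_with_cast(int_x, scale, zero_point, dequantize_fn, cast_fn):
    # dequantize_fn traces to ONNX DequantizeLinear (float32-only scales
    # before opset 19); cast_fn traces to ONNX Cast.
    scale_dtype = scale.dtype
    if scale_dtype != torch.float32:
        scale = cast_fn(scale, torch.float32)  # cast the scale up for DQ
    x = dequantize_fn(int_x, scale, zero_point)
    if scale_dtype in (torch.float16, torch.bfloat16):
        x = cast_fn(x, scale_dtype)  # cast the dequantized output back down
        scale = cast_fn(scale, scale_dtype)  # restore the scale dtype for downstream consumers
    return x, scale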
35 changes: 35 additions & 0 deletions src/brevitas/export/common/handler/qcdq.py
@@ -58,6 +58,13 @@ def clip_fn(self, x, min_val, max_val):
pass


class CastMixin(ABC):

@abstractmethod
def cast_fn(self, x, dtype):
pass


class QMixin(BitWidthHandlerMixin, ABC):

@classmethod
@@ -153,9 +160,16 @@ def symbolic_execution(self, x: Tensor):
scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape')
# Workaround to trick the tracer into believing all return values are used
self.assert_ge_zero(scale, zero_point, bit_width)
scale_dtype = scale.dtype
if scale.dtype != torch.float32:
scale = self.cast_fn(scale, torch.float32)
dequantize_symbolic_kwargs['scale'] = scale
if clip_symbolic_kwargs is not None:
x = self.clip_fn(x, *clip_symbolic_kwargs.values())
x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
x = self.cast_fn(x, scale_dtype)
scale = self.cast_fn(scale, scale_dtype)
# Restore the original shapes to guarantee correct shape propagation downstream
scale = scale.view(scale_orig_shape)
zero_point = zero_point.view_as(scale)
@@ -221,10 +235,19 @@ def symbolic_execution(self, x: Tensor):
bit_width = self.symbolic_kwargs['bit_width']
# Workaround to trick the tracer into believing all return values are used
self.assert_ge_zero(scale, zero_point, bit_width)
scale_dtype = scale.dtype
if x.dtype != torch.float32:
x = self.cast_fn(x, torch.float32)
scale = self.cast_fn(scale, torch.float32)
quantize_symbolic_kwargs['scale'] = scale
dequantize_symbolic_kwargs['scale'] = scale
x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
if clip_symbolic_kwargs is not None:
x = self.clip_fn(x, *clip_symbolic_kwargs.values())
x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
x = self.cast_fn(x, scale_dtype)
scale = self.cast_fn(scale, scale_dtype)
# Restore the original shapes to guarantee correct shape propagation downstream
scale = scale.view(scale_orig_shape)
zero_point = zero_point.view_as(scale)
@@ -275,7 +298,13 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None):
zero_point = to_0dim_if_scalar(zero_point).expand_as(scale)
zero_point = self.zero_point_with_dtype(
True, bit_width, zero_point) # assume signed is True
scale_dtype = scale.dtype
if scale.dtype != torch.float32:
scale = self.cast_fn(scale, torch.float32)
y = self.dequantize_fn(int_bias, scale, zero_point, quant_axis)
if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
y = self.cast_fn(y, scale_dtype)
scale = self.cast_fn(scale, scale_dtype)
# Restore the original shapes to guarantee correct shape propagation downstream
scale = scale.view(scale_orig_shape)
zero_point = zero_point.view_as(scale)
@@ -311,5 +340,11 @@ def symbolic_execution(
signed=signed, narrow=False, bit_width=output_bit_width)
if clip_symbolic_kwargs is not None:
x = self.clip_fn(x, *clip_symbolic_kwargs.values())
flat_scale_dtype = flat_scale.dtype
if flat_scale_dtype != torch.float32:
flat_scale = self.cast_fn(flat_scale, torch.float32)
x = self.dequantize_fn(x, flat_scale, zp, self.quant_axis(scale))
if scale.dtype == torch.float16 or scale.dtype == torch.bfloat16:
x = self.cast_fn(x, scale.dtype)
flat_scale = self.cast_fn(flat_scale, flat_scale_dtype)
return x, scale, zero_point, output_bit_width
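
Taken together, these hunks bracket each quantize/clip/dequantize chain with casts whenever the model runs below float32. For a float16 activation quantizer, the exported graph should look roughly as follows (a sketch inferred from the code above, not a dumped graph):

# x (fp16) -> Cast(to=FLOAT) -> QuantizeLinear -> Clip
#          -> DequantizeLinear -> Cast(to=FLOAT16) -> x (fp16)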
20 changes: 20 additions & 0 deletions src/brevitas/export/onnx/standard/function.py
@@ -1,12 +1,20 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import onnx
import torch
from torch.autograd import Function
from torch.onnx.symbolic_helper import _get_tensor_sizes

from brevitas.export.onnx import onnx_export_opset

AXIS_OPSET = 13

DATATYPE_DICT = {
torch.float32: onnx.TensorProto.DataType.FLOAT,
torch.float16: onnx.TensorProto.DataType.FLOAT16,
torch.bfloat16: onnx.TensorProto.DataType.BFLOAT16}


class DequantizeLinearFn(Function):

@@ -39,6 +47,18 @@ def forward(ctx, int_x, min_int_val, max_int_val):
return int_x


class CastFn(Function):

@staticmethod
def symbolic(g, x, dtype):
ret = g.op('Cast', x, to_i=DATATYPE_DICT[dtype])
return ret

@staticmethod
def forward(ctx, x, dtype):
return x.to(dtype)
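
In eager mode CastFn is a plain dtype conversion; it only becomes an ONNX Cast node under torch.onnx.export. A small usage sketch:

x = torch.randn(4, dtype=torch.float16)
y = CastFn.apply(x, torch.float32)  # eager: equivalent to x.to(torch.float32)
assert y.dtype == torch.float32
# Under export, the same call emits Cast(x, to=onnx.TensorProto.DataType.FLOAT).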


class QuantizeLinearFn(Function):

@staticmethod
24 changes: 16 additions & 8 deletions src/brevitas/export/onnx/standard/qcdq/handler.py
@@ -5,6 +5,7 @@

import torch

from brevitas.export.common.handler.qcdq import CastMixin
from brevitas.export.common.handler.qcdq import CDQMixin
from brevitas.export.common.handler.qcdq import DQMixin
from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin
@@ -17,6 +18,7 @@
from brevitas.export.onnx.handler import ONNXBaseHandler
from brevitas.export.onnx.handler import QuantLSTMLayerHandler

from ..function import CastFn
from ..function import DequantizeLinearFn
from ..function import IntClipFn
from ..function import QuantizeLinearFn
@@ -39,13 +41,19 @@ def validate(self, module):
assert module.bit_width() > 1., 'Binary quant not supported'


class StdCDQONNXMixin(CDQMixin, StdDQONNXMixin, ABC):
class StdDQCastONNXMixin(CastMixin, StdDQONNXMixin):

def cast_fn(self, x, dtype):
return CastFn.apply(x, dtype)


class StdCDQCastONNXMixin(CDQMixin, StdDQCastONNXMixin, ABC):

def clip_fn(self, x, min_val, max_val):
return IntClipFn.apply(x, min_val, max_val)


class StdQCDQONNXMixin(QMixin, StdCDQONNXMixin, ABC):
class StdQCDQCastONNXMixin(QMixin, StdCDQCastONNXMixin, ABC):

@classmethod
def int8_dtype(cls):
@@ -70,36 +78,36 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis):
return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)


class StdQCDQONNXWeightQuantProxyHandler(StdCDQONNXMixin,
class StdQCDQONNXWeightQuantProxyHandler(StdCDQCastONNXMixin,
QCDQWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQONNXDecoupledWeightQuantProxyHandler(StdCDQONNXMixin,
class StdQCDQONNXDecoupledWeightQuantProxyHandler(StdCDQCastONNXMixin,
QCDQDecoupledWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler(
StdCDQONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
StdCDQCastONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
pass


class StdQCDQONNXActQuantProxyHandler(StdQCDQONNXMixin,
class StdQCDQONNXActQuantProxyHandler(StdQCDQCastONNXMixin,
QCDQActQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQONNXBiasQuantProxyHandler(StdDQONNXMixin,
class StdQCDQONNXBiasQuantProxyHandler(StdDQCastONNXMixin,
QCDQBiasQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQONNXTruncQuantProxyHandler(StdQCDQONNXMixin,
class StdQCDQONNXTruncQuantProxyHandler(StdQCDQCastONNXMixin,
QCDQTruncQuantProxyHandlerMixin,
ONNXBaseHandler):
pass
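
With these renames every standard QCDQ handler inherits cast_fn through the mixin chain (StdDQCastONNXMixin -> StdCDQCastONNXMixin -> StdQCDQCastONNXMixin). A quick way to confirm the composition, assuming the module path from this diff:

from brevitas.export.onnx.standard.qcdq.handler import StdQCDQONNXActQuantProxyHandler

# Expect the Cast mixins introduced in this commit to show up in the MRO.
print([cls.__name__ for cls in StdQCDQONNXActQuantProxyHandler.__mro__])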
11 changes: 9 additions & 2 deletions src/brevitas/export/torch/qcdq/handler.py
@@ -52,7 +52,13 @@ def validate(self, module):
assert module.bit_width() > 1., 'Binary quant not supported'


class TorchCDQMixin(TorchDQMixin, ABC):
class TorchDQCastMixin(TorchDQMixin, ABC):

def cast_fn(self, x, dtype):
return x.type(dtype)


class TorchCDQMixin(TorchDQCastMixin, ABC):

def clip_fn(self, x, min_val, max_val):
return torch.clamp(x, min_val, max_val)
@@ -128,7 +134,8 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
return _itemize_clip_bounds(clip_args)


class TorchQCDQBiasQuantProxyHandler(TorchDQMixin, QCDQBiasQuantProxyHandlerMixin,
class TorchQCDQBiasQuantProxyHandler(TorchDQCastMixin,
QCDQBiasQuantProxyHandlerMixin,
TorchQCDQHandler):
pass

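
The torch export path needs no custom autograd Function: Tensor.type already performs the cast in eager mode, which is all TorchDQCastMixin requires. A trivial check of the equivalence:

t = torch.zeros(2, dtype=torch.bfloat16)
t32 = t.type(torch.float32)  # what TorchDQCastMixin.cast_fn does
assert t32.dtype == torch.float32 and torch.equal(t32, t.to(torch.float32))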
11 changes: 7 additions & 4 deletions tests/brevitas_ort/common.py
@@ -12,6 +12,7 @@
from brevitas.export import export_onnx_qcdq
from brevitas.export import export_onnx_qop
from brevitas.export import export_qonnx
from brevitas.export import export_torch_qcdq
from brevitas.nn import QuantConv1d
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantConvTranspose1d
@@ -114,10 +115,12 @@ def recursive_allclose(ort_output, brevitas_output, tolerance):
def is_brevitas_ort_close(
model, np_input, export_name, export_type, tolerance=None, first_output_only=False):
input_t = torch.from_numpy(np_input)
brevitas_output = model(input_t)

if tolerance is not None and export_type == 'qcdq':
tolerance = tolerance * brevitas_output.scale # Float output, tolerance is +/- output scale

if export_type == 'qonnx':
exported_model = export_qonnx(model, input_t, export_path=export_name)
@@ -131,7 +134,7 @@ def is_brevitas_ort_close(
export_onnx_qop(model, input_t, export_path=export_name)
brevitas_output = brevitas_output.int(float_datatype=False)
elif export_type == 'qcdq':
export_onnx_qcdq(model, input_t, export_path=export_name)
export_onnx_qcdq(model, input_t, export_path=export_name, opset_version=19)
elif export_type == 'qcdq_opset14':
export_onnx_qcdq(model, input_t, opset_version=14, export_path=export_name)
elif export_type == 'qonnx_opset14':
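
A plausible rationale for pinning opset_version=19 on the qcdq path: opset 19 is the first in which QuantizeLinear/DequantizeLinear natively admit float16 and bfloat16 scales, matching the reduced-precision graphs this commit produces. The equivalent standalone export call (model/input_t as in the test below):

from brevitas.export import export_onnx_qcdq
export_onnx_qcdq(model, input_t, export_path='fp16_model.onnx', opset_version=19)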
10 changes: 6 additions & 4 deletions tests/brevitas_ort/test_quant_module.py
@@ -20,13 +20,14 @@

@parametrize_with_cases('model', cases=QuantWBIOLCases)
@pytest.mark.parametrize('export_type', ['qcdq', 'qonnx', 'qop'])
@pytest.mark.parametrize('dtype', ['float16', 'float32']) # 'float32' resolves under both the numpy and torch getattr lookups below
@requires_pt_ge('1.8.1')
def test_ort_wbiol(model, export_type, current_cases):
def test_ort_wbiol(model, export_type, dtype, current_cases):
cases_generator_func = current_cases['model'][1]
case_id = get_case_id(cases_generator_func)
impl = case_id.split('-')[
-2] # Inverse list of definition, 'export_type' is -1, 'impl' is -2, etc.
quantizer = case_id.split('-')[-6]
-3] # Inverse list of definition, 'dtype' is -1, 'export_type' is -2, 'impl' is -3, etc.
quantizer = case_id.split('-')[-7]

if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d') and export_type == 'qop':
pytest.skip('Export of ConvTranspose is not supported for QOperation')
@@ -43,9 +44,10 @@ def test_ort_wbiol(model, export_type, current_cases):
in_size = (1, IN_CH, FEATURES, FEATURES)

inp = gen_linspaced_data(reduce(mul, in_size), -1, 1).reshape(in_size)

model(torch.from_numpy(inp)) # accumulate scale factors
model.eval()
inp = inp.astype(getattr(np, dtype))
model = model.to(getattr(torch, dtype))
export_name = f'qcdq_qop_export_{case_id}.onnx'
assert is_brevitas_ort_close(
model, inp, export_name, export_type, tolerance=INT_TOLERANCE, first_output_only=True)
