diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py
index 1aecaaeb4..a91740808 100644
--- a/src/brevitas/export/common/handler/qcdq.py
+++ b/src/brevitas/export/common/handler/qcdq.py
@@ -58,6 +58,13 @@ def clip_fn(self, x, min_val, max_val):
         pass
 
 
+class CastMixin(ABC):
+
+    @abstractmethod
+    def cast_fn(self, x, dtype):
+        pass
+
+
 class QMixin(BitWidthHandlerMixin, ABC):
 
     @classmethod
@@ -153,9 +160,16 @@ def symbolic_execution(self, x: Tensor):
         scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape')
         # Workaround to trick the tracer into believing all return values are used
         self.assert_ge_zero(scale, zero_point, bit_width)
+        scale_dtype = scale.dtype
+        if scale.dtype != torch.float32:
+            scale = self.cast_fn(scale, torch.float32)
+            dequantize_symbolic_kwargs['scale'] = scale
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
+        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
+            x = self.cast_fn(x, scale_dtype)
+            scale = self.cast_fn(scale, scale_dtype)
         # Restore the original shapes to guarantee correct shape propagation downstream
         scale = scale.view(scale_orig_shape)
         zero_point = zero_point.view_as(scale)
@@ -221,10 +235,19 @@ def symbolic_execution(self, x: Tensor):
         bit_width = self.symbolic_kwargs['bit_width']
         # Workaround to trick the tracer into believing all return values are used
         self.assert_ge_zero(scale, zero_point, bit_width)
+        scale_dtype = scale.dtype
+        if x.dtype != torch.float32:
+            x = self.cast_fn(x, torch.float32)
+            scale = self.cast_fn(scale, torch.float32)
+            quantize_symbolic_kwargs['scale'] = scale
+            dequantize_symbolic_kwargs['scale'] = scale
         x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
+        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
+            x = self.cast_fn(x, scale_dtype)
+            scale = self.cast_fn(scale, scale_dtype)
         # Restore the original shapes to guarantee correct shape propagation downstream
         scale = scale.view(scale_orig_shape)
         zero_point = zero_point.view_as(scale)
@@ -275,7 +298,13 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None):
         zero_point = to_0dim_if_scalar(zero_point).expand_as(scale)
         zero_point = self.zero_point_with_dtype(
             True, bit_width, zero_point)  # assume signed is True
+        scale_dtype = scale.dtype
+        if scale.dtype != torch.float32:
+            scale = self.cast_fn(scale, torch.float32)
         y = self.dequantize_fn(int_bias, scale, zero_point, quant_axis)
+        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
+            y = self.cast_fn(y, scale_dtype)
+            scale = self.cast_fn(scale, scale_dtype)
         # Restore the original shapes to guarantee correct shape propagation downstream
         scale = scale.view(scale_orig_shape)
         zero_point = zero_point.view_as(scale)
@@ -311,5 +340,11 @@ def symbolic_execution(
             signed=signed, narrow=False, bit_width=output_bit_width)
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
+        flat_scale_dtype = flat_scale.dtype
+        if flat_scale_dtype != torch.float32:
+            flat_scale = self.cast_fn(flat_scale, torch.float32)
         x = self.dequantize_fn(x, flat_scale, zp, self.quant_axis(scale))
+        if scale.dtype == torch.float16 or scale.dtype == torch.bfloat16:
+            x = self.cast_fn(x, scale.dtype)
+            flat_scale = self.cast_fn(flat_scale, flat_scale_dtype)
         return x, scale, zero_point, output_bit_width
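All of the hunks above apply the same cast-around pattern: ONNX `QuantizeLinear`/`DequantizeLinear` only accept float32 scales prior to opset 19, so a float16/bfloat16 scale (and, for activations, the input itself) is cast up to float32 before the Q/DQ pair, and the output is cast back to the original dtype afterwards so downstream ops keep the network's compute dtype. A minimal standalone sketch of the idea (hypothetical helper, not Brevitas code):

```python
# Hypothetical standalone sketch (not Brevitas code) of the cast-around
# pattern used in the hunks above: run the quantize/dequantize pair in
# float32, then cast the result back to the network's original dtype.
import torch


def qdq_with_cast(x: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor:
    scale_dtype = scale.dtype
    if scale_dtype != torch.float32:
        x = x.to(torch.float32)
        scale = scale.to(torch.float32)
        zero_point = zero_point.to(torch.float32)
    # Quantize to the int8 range, then dequantize, entirely in float32
    q = torch.clamp(torch.round(x / scale) + zero_point, -128, 127)
    y = (q - zero_point) * scale
    if scale_dtype in (torch.float16, torch.bfloat16):
        y = y.to(scale_dtype)  # restore the original dtype for downstream ops
    return y


y = qdq_with_cast(
    torch.randn(4, dtype=torch.float16),
    torch.tensor(0.1, dtype=torch.float16),
    torch.tensor(0., dtype=torch.float16))
assert y.dtype == torch.float16
```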
diff --git a/src/brevitas/export/onnx/standard/function.py b/src/brevitas/export/onnx/standard/function.py
index 1a1de5b15..9889720a7 100644
--- a/src/brevitas/export/onnx/standard/function.py
+++ b/src/brevitas/export/onnx/standard/function.py
@@ -1,12 +1,20 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+import onnx
+import torch
 from torch.autograd import Function
+from torch.onnx.symbolic_helper import _get_tensor_sizes
 
 from brevitas.export.onnx import onnx_export_opset
 
 AXIS_OPSET = 13
 
+DATATYPE_DICT = {
+    torch.float32: onnx.TensorProto.DataType.FLOAT,
+    torch.float16: onnx.TensorProto.DataType.FLOAT16,
+    torch.bfloat16: onnx.TensorProto.DataType.BFLOAT16}
+
 
 class DequantizeLinearFn(Function):
 
@@ -39,6 +47,18 @@ def forward(ctx, int_x, min_int_val, max_int_val):
         return int_x
 
 
+class CastFn(Function):
+
+    @staticmethod
+    def symbolic(g, x, dtype):
+        ret = g.op('Cast', x, to_i=DATATYPE_DICT[dtype])
+        return ret
+
+    @staticmethod
+    def forward(ctx, x, dtype):
+        return x.to(dtype)
+
+
 class QuantizeLinearFn(Function):
 
     @staticmethod
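`CastFn` mirrors the other autograd functions in this file: in eager mode its `forward` is just `Tensor.to`, while under `torch.onnx.export` the `symbolic` method emits a single ONNX `Cast` node whose `to` attribute is looked up in `DATATYPE_DICT`. A usage sketch, assuming the patched module layout above:

```python
# Usage sketch, assuming the patched brevitas module layout above.
import torch

from brevitas.export.onnx.standard.function import CastFn

x = torch.randn(2, 2, dtype=torch.bfloat16)
y = CastFn.apply(x, torch.float32)  # eager mode: plain x.to(dtype)
assert y.dtype == torch.float32     # during export: a single ONNX Cast node
```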
diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py
index 165f4ad6c..f0017b841 100644
--- a/src/brevitas/export/onnx/standard/qcdq/handler.py
+++ b/src/brevitas/export/onnx/standard/qcdq/handler.py
@@ -5,6 +5,7 @@
 
 import torch
 
+from brevitas.export.common.handler.qcdq import CastMixin
 from brevitas.export.common.handler.qcdq import CDQMixin
 from brevitas.export.common.handler.qcdq import DQMixin
 from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin
@@ -17,6 +18,7 @@
 from brevitas.export.onnx.handler import ONNXBaseHandler
 from brevitas.export.onnx.handler import QuantLSTMLayerHandler
 
+from ..function import CastFn
 from ..function import DequantizeLinearFn
 from ..function import IntClipFn
 from ..function import QuantizeLinearFn
@@ -39,13 +41,19 @@ def validate(self, module):
         assert module.bit_width() > 1., 'Binary quant not supported'
 
 
-class StdCDQONNXMixin(CDQMixin, StdDQONNXMixin, ABC):
+class StdDQCastONNXMixin(CastMixin, StdDQONNXMixin):
+
+    def cast_fn(self, x, dtype):
+        return CastFn.apply(x, dtype)
+
+
+class StdCDQCastONNXMixin(CDQMixin, StdDQCastONNXMixin, ABC):
 
     def clip_fn(self, x, min_val, max_val):
         return IntClipFn.apply(x, min_val, max_val)
 
 
-class StdQCDQONNXMixin(QMixin, StdCDQONNXMixin, ABC):
+class StdQCDQCastONNXMixin(QMixin, StdCDQCastONNXMixin, ABC):
 
     @classmethod
     def int8_dtype(cls):
@@ -70,36 +78,36 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis):
         return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)
 
 
-class StdQCDQONNXWeightQuantProxyHandler(StdCDQONNXMixin,
+class StdQCDQONNXWeightQuantProxyHandler(StdCDQCastONNXMixin,
                                          QCDQWeightQuantProxyHandlerMixin,
                                          ONNXBaseHandler):
     pass
 
 
-class StdQCDQONNXDecoupledWeightQuantProxyHandler(StdCDQONNXMixin,
+class StdQCDQONNXDecoupledWeightQuantProxyHandler(StdCDQCastONNXMixin,
                                                   QCDQDecoupledWeightQuantProxyHandlerMixin,
                                                   ONNXBaseHandler):
     pass
 
 
 class StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler(
-        StdCDQONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
+        StdCDQCastONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
     pass
 
 
-class StdQCDQONNXActQuantProxyHandler(StdQCDQONNXMixin,
+class StdQCDQONNXActQuantProxyHandler(StdQCDQCastONNXMixin,
                                       QCDQActQuantProxyHandlerMixin,
                                       ONNXBaseHandler):
     pass
 
 
-class StdQCDQONNXBiasQuantProxyHandler(StdDQONNXMixin,
+class StdQCDQONNXBiasQuantProxyHandler(StdDQCastONNXMixin,
                                        QCDQBiasQuantProxyHandlerMixin,
                                        ONNXBaseHandler):
     pass
 
 
-class StdQCDQONNXTruncQuantProxyHandler(StdQCDQONNXMixin,
+class StdQCDQONNXTruncQuantProxyHandler(StdQCDQCastONNXMixin,
                                         QCDQTruncQuantProxyHandlerMixin,
                                         ONNXBaseHandler):
     pass
diff --git a/src/brevitas/export/torch/qcdq/handler.py b/src/brevitas/export/torch/qcdq/handler.py
index e80e20dac..309b98410 100644
--- a/src/brevitas/export/torch/qcdq/handler.py
+++ b/src/brevitas/export/torch/qcdq/handler.py
@@ -52,7 +52,13 @@ def validate(self, module):
         assert module.bit_width() > 1., 'Binary quant not supported'
 
 
-class TorchCDQMixin(TorchDQMixin, ABC):
+class TorchDQCastMixin(TorchDQMixin, ABC):
+
+    def cast_fn(self, x, dtype):
+        return x.type(dtype)
+
+
+class TorchCDQMixin(TorchDQCastMixin, ABC):
 
     def clip_fn(self, x, min_val, max_val):
         return torch.clamp(x, min_val, max_val)
@@ -128,7 +134,8 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
         return _itemize_clip_bounds(clip_args)
 
 
-class TorchQCDQBiasQuantProxyHandler(TorchDQMixin, QCDQBiasQuantProxyHandlerMixin,
+class TorchQCDQBiasQuantProxyHandler(TorchDQCastMixin,
+                                     QCDQBiasQuantProxyHandlerMixin,
                                      TorchQCDQHandler):
     pass
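Note the asymmetry between backends: the ONNX mixin routes `cast_fn` through `CastFn` so an explicit `Cast` node lands in the exported graph, while the Torch backend can simply call `Tensor.type`, since its output is executed directly rather than traced into ONNX. A condensed sketch of the mixin dispatch (hypothetical class names, not the real handler hierarchy):

```python
# Condensed sketch of the mixin dispatch (hypothetical names): shared
# handler code calls self.cast_fn(), and each backend mixin supplies it.
from abc import ABC, abstractmethod

import torch


class CastMixin(ABC):

    @abstractmethod
    def cast_fn(self, x, dtype):
        pass


class SharedDQHandler(CastMixin, ABC):
    # Backend-agnostic logic, analogous to the common qcdq.py handlers

    def dequantize(self, int_x, scale):
        scale_dtype = scale.dtype
        if scale_dtype != torch.float32:
            scale = self.cast_fn(scale, torch.float32)
        y = int_x.to(torch.float32) * scale
        if scale_dtype in (torch.float16, torch.bfloat16):
            y = self.cast_fn(y, scale_dtype)
        return y


class TorchBackendDQHandler(SharedDQHandler):

    def cast_fn(self, x, dtype):
        return x.type(dtype)  # same choice as TorchDQCastMixin above


handler = TorchBackendDQHandler()
out = handler.dequantize(torch.ones(3, dtype=torch.int8), torch.tensor(0.5, dtype=torch.float16))
assert out.dtype == torch.float16
```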