Feat (export): (b)float16 support for qcdq export
Giuseppe5 committed Dec 11, 2023
1 parent 422b632 commit 84d25a6
Showing 4 changed files with 104 additions and 16 deletions.
66 changes: 60 additions & 6 deletions src/brevitas/export/common/handler/qcdq.py
@@ -58,6 +58,13 @@ def clip_fn(self, x, min_val, max_val):
pass


class CastMixin(ABC):

@abstractmethod
def cast_fn(self, x, dtype):
pass


class QMixin(BitWidthHandlerMixin, ABC):

@classmethod
@@ -129,15 +136,21 @@ def prepare_for_export(self, module):
for tm in module.tracked_module_list}
# Get the first quant weight as representative
quant_weight = module.tracked_module_list[0].quant_weight()

# (B)float16 is not supported by the standard Q/DQ ops, so we store the original dtype
# of the scale and cast it to float32.
# The original dtype is then restored during the forward pass.
scale = quant_weight.scale
self.scale_dtype = scale.dtype
if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
scale = self.cast_fn(scale, torch.float32)

self.symbolic_kwargs['int_weights'] = int_weights
self.symbolic_kwargs['bit_width'] = quant_weight.bit_width
self.symbolic_kwargs['clip_symbolic_kwargs'] = self.int_clip_symbolic_kwargs(
module.is_narrow_range, module.is_signed, quant_weight.bit_width)
self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
- quant_weight.scale,
- quant_weight.zero_point,
- quant_weight.bit_width,
- module.is_signed)
+ scale, quant_weight.zero_point, quant_weight.bit_width, module.is_signed)
else:
self.symbolic_kwargs = None

@@ -156,6 +169,10 @@ def symbolic_execution(self, x: Tensor):
if clip_symbolic_kwargs is not None:
x = self.clip_fn(x, *clip_symbolic_kwargs.values())
x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
# After dequantization, cast both the output and the scale back to the original dtype
if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
x = self.cast_fn(x, self.scale_dtype)
scale = self.cast_fn(scale, self.scale_dtype)
# Restore the original shapes to guarantee correct shape propagation downstream
scale = scale.view(scale_orig_shape)
zero_point = zero_point.view_as(scale)
@@ -200,12 +217,22 @@ def prepare_for_export(self, module):
if module.is_quant_enabled:
self.validate(module)
self.symbolic_kwargs['bit_width'] = module.bit_width()

# (B)float16 is not supported by the standard Q/DQ ops, so we store the original dtype
# of the scale and cast it to float32.
# The original dtype is then restored during the forward pass.
scale = module.scale()
self.scale_dtype = scale.dtype
if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
scale = self.cast_fn(scale, torch.float32)

self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs(
- module.scale(), module.zero_point(), module.bit_width(), module.is_signed)
+ scale, module.zero_point(), module.bit_width(), module.is_signed)
self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
- module.scale(), module.zero_point(), module.bit_width(), module.is_signed)
+ scale, module.zero_point(), module.bit_width(), module.is_signed)
self.symbolic_kwargs['clip_symbolic_kwargs'] = self.int_clip_symbolic_kwargs(
module.is_narrow_range, module.is_signed, module.bit_width())

else:
self.symbolic_kwargs = None

@@ -221,10 +248,17 @@ def symbolic_execution(self, x: Tensor):
bit_width = self.symbolic_kwargs['bit_width']
# Workaround to trick the tracer into believing all return values are used
self.assert_ge_zero(scale, zero_point, bit_width)
# If the original dtype of the input is (b)float16, cast the input to float32
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
x = self.cast_fn(x, torch.float32)
x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
if clip_symbolic_kwargs is not None:
x = self.clip_fn(x, *clip_symbolic_kwargs.values())
x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
# After dequantization, cast both output and scale to the correct dtype
if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
x = self.cast_fn(x, self.scale_dtype)
scale = self.cast_fn(scale, self.scale_dtype)
# Restore the original shapes to guarantee correct shape propagation downstream
scale = scale.view(scale_orig_shape)
zero_point = zero_point.view_as(scale)
@@ -275,7 +309,16 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None):
zero_point = to_0dim_if_scalar(zero_point).expand_as(scale)
zero_point = self.zero_point_with_dtype(
True, bit_width, zero_point) # assume signed is True
# If the original dtype of the scale is (b)float16, store the original dtype
# and cast the scale to float32
scale_dtype = scale.dtype
if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
scale = self.cast_fn(scale, torch.float32)
y = self.dequantize_fn(int_bias, scale, zero_point, quant_axis)
# After dequantization, cast both output and scale to the correct dtype
if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
y = self.cast_fn(y, scale_dtype)
scale = self.cast_fn(scale, scale_dtype)
# Restore the original shapes to guarantee correct shape propagation downstream
scale = scale.view(scale_orig_shape)
zero_point = zero_point.view_as(scale)
@@ -302,6 +345,13 @@ def symbolic_execution(
output_bit_width = self.symbolic_kwargs['output_bit_width']
dtype = self.int8_dtype() if signed else self.uint8_dtype()
trunc_scale = 2.0 ** (input_bit_width - output_bit_width)
# If the original dtype of the scale is (b)float16, store the original scale dtype
# and cast both the scale and the input to float32
scale_dtype = scale.dtype
if scale_dtype == torch.bfloat16 or scale_dtype == torch.float16:
scale = self.cast_fn(scale, torch.float32)
if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
x = self.cast_fn(x, torch.float32)
pre_scale = scale * trunc_scale
flat_pre_scale = to_0dim_if_scalar(pre_scale.flatten())
flat_scale = to_0dim_if_scalar(scale.flatten())
@@ -312,4 +362,8 @@ def symbolic_execution(
if clip_symbolic_kwargs is not None:
x = self.clip_fn(x, *clip_symbolic_kwargs.values())
x = self.dequantize_fn(x, flat_scale, zp, self.quant_axis(scale))
# After dequantization, cast both output and scale to the correct dtype
if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
x = self.cast_fn(x, scale_dtype)
scale = self.cast_fn(scale, scale_dtype)
return x, scale, zero_point, output_bit_width
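
Taken together, the changes in this file follow a single pattern: whenever the scale (and, for activations, the input) is float16 or bfloat16, the handler casts to float32 before the standard Q/DQ ops and casts the dequantized output and the scale back to the original dtype afterwards. The sketch below only illustrates that pattern; qdq_with_cast, quantize_fn and dequantize_fn are hypothetical names, not functions from the codebase.

import torch

def qdq_with_cast(x, scale, zero_point, quantize_fn, dequantize_fn):
    # The standard Q/DQ ops do not accept (b)float16 scales, so remember the
    # original dtype and cast up to float32 before quantization.
    scale_dtype = scale.dtype
    if scale_dtype in (torch.float16, torch.bfloat16):
        scale = scale.to(torch.float32)
        x = x.to(torch.float32)
    x = quantize_fn(x, scale, zero_point)
    x = dequantize_fn(x, scale, zero_point)
    # After dequantization, restore the original (b)float16 dtype on both
    # the output and the scale.
    if scale_dtype in (torch.float16, torch.bfloat16):
        x = x.to(scale_dtype)
        scale = scale.to(scale_dtype)
    return x, scale
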
19 changes: 19 additions & 0 deletions src/brevitas/export/onnx/standard/function.py
@@ -1,12 +1,19 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import onnx
import torch
from torch.autograd import Function

from brevitas.export.onnx import onnx_export_opset

AXIS_OPSET = 13

DATATYPE_DICT = {
torch.float32: onnx.TensorProto.DataType.FLOAT,
torch.float16: onnx.TensorProto.DataType.FLOAT16,
torch.bfloat16: onnx.TensorProto.DataType.BFLOAT16}


class DequantizeLinearFn(Function):

@@ -39,6 +46,18 @@ def forward(ctx, int_x, min_int_val, max_int_val):
return int_x


class CastFn(Function):

@staticmethod
def symbolic(g, x, dtype):
ret = g.op('Cast', x, to_i=DATATYPE_DICT[dtype])
return ret

@staticmethod
def forward(ctx, x, dtype):
return x.to(dtype)


class QuantizeLinearFn(Function):

@staticmethod
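
As a side note that is not part of the diff: in eager mode CastFn.apply behaves like a plain .to(), while during ONNX export its symbolic method emits a Cast node whose target type is looked up in DATATYPE_DICT. A small usage sketch, importing from the module shown in the file header above:

import torch

from brevitas.export.onnx.standard.function import CastFn

scale_fp16 = torch.tensor([0.05], dtype=torch.float16)
scale_fp32 = CastFn.apply(scale_fp16, torch.float32)  # eager: same as scale_fp16.to(torch.float32)
assert scale_fp32.dtype == torch.float32
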
24 changes: 16 additions & 8 deletions src/brevitas/export/onnx/standard/qcdq/handler.py
@@ -5,6 +5,7 @@

import torch

from brevitas.export.common.handler.qcdq import CastMixin
from brevitas.export.common.handler.qcdq import CDQMixin
from brevitas.export.common.handler.qcdq import DQMixin
from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin
@@ -17,6 +18,7 @@
from brevitas.export.onnx.handler import ONNXBaseHandler
from brevitas.export.onnx.handler import QuantLSTMLayerHandler

from ..function import CastFn
from ..function import DequantizeLinearFn
from ..function import IntClipFn
from ..function import QuantizeLinearFn
@@ -39,13 +41,19 @@ def validate(self, module):
assert module.bit_width() > 1., 'Binary quant not supported'


- class StdCDQONNXMixin(CDQMixin, StdDQONNXMixin, ABC):
+ class StdDQCastONNXMixin(CastMixin, StdDQONNXMixin):

def cast_fn(self, x, dtype):
return CastFn.apply(x, dtype)


class StdCDQCastONNXMixin(CDQMixin, StdDQCastONNXMixin, ABC):

def clip_fn(self, x, min_val, max_val):
return IntClipFn.apply(x, min_val, max_val)


- class StdQCDQONNXMixin(QMixin, StdCDQONNXMixin, ABC):
+ class StdQCDQCastONNXMixin(QMixin, StdCDQCastONNXMixin, ABC):

@classmethod
def int8_dtype(cls):
@@ -70,36 +78,36 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis):
return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)


- class StdQCDQONNXWeightQuantProxyHandler(StdCDQONNXMixin,
+ class StdQCDQONNXWeightQuantProxyHandler(StdCDQCastONNXMixin,
QCDQWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


- class StdQCDQONNXDecoupledWeightQuantProxyHandler(StdCDQONNXMixin,
+ class StdQCDQONNXDecoupledWeightQuantProxyHandler(StdCDQCastONNXMixin,
QCDQDecoupledWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler(
- StdCDQONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
+ StdCDQCastONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
pass


- class StdQCDQONNXActQuantProxyHandler(StdQCDQONNXMixin,
+ class StdQCDQONNXActQuantProxyHandler(StdQCDQCastONNXMixin,
QCDQActQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


- class StdQCDQONNXBiasQuantProxyHandler(StdDQONNXMixin,
+ class StdQCDQONNXBiasQuantProxyHandler(StdDQCastONNXMixin,
QCDQBiasQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


- class StdQCDQONNXTruncQuantProxyHandler(StdQCDQONNXMixin,
+ class StdQCDQONNXTruncQuantProxyHandler(StdQCDQCastONNXMixin,
QCDQTruncQuantProxyHandlerMixin,
ONNXBaseHandler):
pass
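
With the handlers rewired onto their Cast variants, a half-precision quantized layer should export through the QCDQ path with Cast nodes wrapped around the Q/DQ pattern. A hedged end-to-end sketch, assuming the export_onnx_qcdq entry point and default quantizers (neither appears in this diff):

import torch

import brevitas.nn as qnn
from brevitas.export import export_onnx_qcdq

model = qnn.QuantLinear(16, 8, bias=False).to(torch.float16)
inp = torch.randn(1, 16, dtype=torch.float16)
# The exported graph is expected to contain a Cast to float32 before QuantizeLinear
# and a Cast back to float16 after DequantizeLinear.
export_onnx_qcdq(model, args=inp, export_path='quant_linear_fp16.onnx')
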
11 changes: 9 additions & 2 deletions src/brevitas/export/torch/qcdq/handler.py
@@ -52,7 +52,13 @@ def validate(self, module):
assert module.bit_width() > 1., 'Binary quant not supported'


- class TorchCDQMixin(TorchDQMixin, ABC):
+ class TorchDQCastMixin(TorchDQMixin, ABC):

def cast_fn(self, x, dtype):
return x.type(dtype)


class TorchCDQMixin(TorchDQCastMixin, ABC):

def clip_fn(self, x, min_val, max_val):
return torch.clamp(x, min_val, max_val)
@@ -128,7 +134,8 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
return _itemize_clip_bounds(clip_args)


- class TorchQCDQBiasQuantProxyHandler(TorchDQMixin, QCDQBiasQuantProxyHandlerMixin,
+ class TorchQCDQBiasQuantProxyHandler(TorchDQCastMixin,
+ QCDQBiasQuantProxyHandlerMixin,
TorchQCDQHandler):
pass

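
On the TorchScript side no autograd Function is needed: TorchDQCastMixin.cast_fn uses a plain tensor cast, which in eager mode matches what CastFn.apply does for the ONNX backend. A trivial equivalence check on eager tensors:

import torch

from brevitas.export.onnx.standard.function import CastFn

x = torch.zeros(4, dtype=torch.bfloat16)
# Torch backend cast (as in TorchDQCastMixin.cast_fn) vs the ONNX CastFn, compared eagerly.
assert torch.equal(x.type(torch.float32), CastFn.apply(x, torch.float32))
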
