Feat (export): (b)float16 support for qcdq export #776

Merged · 3 commits · Dec 21, 2023
66 changes: 60 additions & 6 deletions src/brevitas/export/common/handler/qcdq.py
@@ -51,6 +51,13 @@ def assert_ge_zero(self, *args):
assert bools


class DQCastMixin(DQMixin, ABC):

@abstractmethod
def cast_fn(self, x, dtype):
pass


class CDQMixin(DQMixin, ABC):

@abstractmethod
@@ -129,15 +136,21 @@ def prepare_for_export(self, module):
for tm in module.tracked_module_list}
# Get the first quant weight as representative
quant_weight = module.tracked_module_list[0].quant_weight()

# (B)float16 is not supported with standard Q/DQ ops, thus we store the original dtype
# of the scale and we cast it to float32.
# The original dtype is then restored during the forward pass
scale = quant_weight.scale
self.scale_dtype = scale.dtype
if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
scale = self.cast_fn(scale, torch.float32)

self.symbolic_kwargs['int_weights'] = int_weights
self.symbolic_kwargs['bit_width'] = quant_weight.bit_width
self.symbolic_kwargs['clip_symbolic_kwargs'] = self.int_clip_symbolic_kwargs(
module.is_narrow_range, module.is_signed, quant_weight.bit_width)
self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
quant_weight.scale,
quant_weight.zero_point,
quant_weight.bit_width,
module.is_signed)
scale, quant_weight.zero_point, quant_weight.bit_width, module.is_signed)
else:
self.symbolic_kwargs = None

@@ -156,6 +169,10 @@ def symbolic_execution(self, x: Tensor):
if clip_symbolic_kwargs is not None:
x = self.clip_fn(x, *clip_symbolic_kwargs.values())
x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
# After dequantization, cast both input and scale to the correct dtype
if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
x = self.cast_fn(x, self.scale_dtype)
scale = self.cast_fn(scale, self.scale_dtype)
# Restore the original shapes to guarantee correct shape propagation downstream
scale = scale.view(scale_orig_shape)
zero_point = zero_point.view_as(scale)
@@ -200,12 +217,22 @@ def prepare_for_export(self, module):
if module.is_quant_enabled:
self.validate(module)
self.symbolic_kwargs['bit_width'] = module.bit_width()

# (B)float16 is not supported with standard Q/DQ ops, thus we store the original dtype
# of the scale and we cast it to float32.
# The original dtype is then restored during the forward pass
scale = module.scale()
self.scale_dtype = scale.dtype
if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
scale = self.cast_fn(scale, torch.float32)

self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs(
module.scale(), module.zero_point(), module.bit_width(), module.is_signed)
scale, module.zero_point(), module.bit_width(), module.is_signed)
self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
module.scale(), module.zero_point(), module.bit_width(), module.is_signed)
scale, module.zero_point(), module.bit_width(), module.is_signed)
self.symbolic_kwargs['clip_symbolic_kwargs'] = self.int_clip_symbolic_kwargs(
module.is_narrow_range, module.is_signed, module.bit_width())

else:
self.symbolic_kwargs = None

@@ -221,10 +248,17 @@ def symbolic_execution(self, x: Tensor):
bit_width = self.symbolic_kwargs['bit_width']
# Workaround to trick the tracer into believing all return values are used
self.assert_ge_zero(scale, zero_point, bit_width)
# If original dtype of the input is (b)float16, cast the input to float32
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
x = self.cast_fn(x, torch.float32)
x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
if clip_symbolic_kwargs is not None:
x = self.clip_fn(x, *clip_symbolic_kwargs.values())
x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
# After dequantization, cast both output and scale to the correct dtype
if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
x = self.cast_fn(x, self.scale_dtype)
scale = self.cast_fn(scale, self.scale_dtype)
# Restore the original shapes to guarantee correct shape propagation downstream
scale = scale.view(scale_orig_shape)
zero_point = zero_point.view_as(scale)
@@ -275,7 +309,16 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None):
zero_point = to_0dim_if_scalar(zero_point).expand_as(scale)
zero_point = self.zero_point_with_dtype(
True, bit_width, zero_point) # assume signed is True
# If original dtype of scale is (b)float16, store the original dtype
# and cast the scale to float32
scale_dtype = scale.dtype
if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
scale = self.cast_fn(scale, torch.float32)
y = self.dequantize_fn(int_bias, scale, zero_point, quant_axis)
# After dequantization, cast both output and scale to the correct dtype
if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
y = self.cast_fn(y, scale_dtype)
scale = self.cast_fn(scale, scale_dtype)
# Restore the original shapes to guarantee correct shape propagation downstream
scale = scale.view(scale_orig_shape)
zero_point = zero_point.view_as(scale)
@@ -302,6 +345,13 @@ def symbolic_execution(
output_bit_width = self.symbolic_kwargs['output_bit_width']
dtype = self.int8_dtype() if signed else self.uint8_dtype()
trunc_scale = 2.0 ** (input_bit_width - output_bit_width)
# If original dtype of scale is (b)float16, store the original scale dtype
# and cast the scale and the input to float32
scale_dtype = scale.dtype
if scale_dtype == torch.bfloat16 or scale_dtype == torch.float16:
scale = self.cast_fn(scale, torch.float32)
if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
x = self.cast_fn(x, torch.float32)
pre_scale = scale * trunc_scale
flat_pre_scale = to_0dim_if_scalar(pre_scale.flatten())
flat_scale = to_0dim_if_scalar(scale.flatten())
Expand All @@ -312,4 +362,8 @@ def symbolic_execution(
if clip_symbolic_kwargs is not None:
x = self.clip_fn(x, *clip_symbolic_kwargs.values())
x = self.dequantize_fn(x, flat_scale, zp, self.quant_axis(scale))
# After dequantization, cast both output and scale to the correct dtype
if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
x = self.cast_fn(x, scale_dtype)
scale = self.cast_fn(scale, scale_dtype)
return x, scale, zero_point, output_bit_width
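
The hunks above apply the same pattern in every handler: since, as the new comments note, (b)float16 is not supported by the standard Q/DQ ops, the scale's original dtype is recorded, the scale (and, on the activation path, the input) is cast to float32 before Q/DQ, and both the output and the scale are cast back afterwards. A minimal, self-contained sketch of that pattern; the function name and callable arguments are illustrative, not part of the PR:

```python
import torch


def dequantize_with_cast(int_x, scale, zero_point, dequantize_fn, cast_fn):
    # Sketch of the cast-and-restore pattern used in the handlers above.
    # Standard Q/DQ ops expect float32 scales, so a (b)float16 scale is widened
    # before dequantization, and both the output and the scale are narrowed back
    # to the original dtype afterwards.
    scale_dtype = scale.dtype
    if scale_dtype in (torch.float16, torch.bfloat16):
        scale = cast_fn(scale, torch.float32)
    x = dequantize_fn(int_x, scale, zero_point)
    if scale_dtype in (torch.float16, torch.bfloat16):
        x = cast_fn(x, scale_dtype)
        scale = cast_fn(scale, scale_dtype)
    return x, scale
```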
19 changes: 19 additions & 0 deletions src/brevitas/export/onnx/standard/function.py
@@ -1,12 +1,19 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import onnx
import torch
from torch.autograd import Function

from brevitas.export.onnx import onnx_export_opset

AXIS_OPSET = 13

DATATYPE_DICT = {
torch.float32: onnx.TensorProto.DataType.FLOAT,
torch.float16: onnx.TensorProto.DataType.FLOAT16,
torch.bfloat16: onnx.TensorProto.DataType.BFLOAT16}


class DequantizeLinearFn(Function):

@@ -39,6 +46,18 @@ def forward(ctx, int_x, min_int_val, max_int_val):
return int_x


class CastFn(Function):

@staticmethod
def symbolic(g, x, dtype):
ret = g.op('Cast', x, to_i=DATATYPE_DICT[dtype])
return ret

@staticmethod
def forward(ctx, x, dtype):
return x.to(dtype)


class QuantizeLinearFn(Function):

@staticmethod
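
For reference, a small sketch (not part of the diff) of how the new CastFn above behaves: in eager mode, forward() simply calls x.to(dtype), while during ONNX export the symbolic() method emits a standard Cast node whose target type is looked up in DATATYPE_DICT.

```python
import torch

from brevitas.export.onnx.standard.function import CastFn

x = torch.randn(4, dtype=torch.float16)
y = CastFn.apply(x, torch.float32)  # eager path: plain x.to(torch.float32)
assert y.dtype == torch.float32

# Under torch.onnx.export the symbolic() path is taken instead and produces
# g.op('Cast', x, to_i=DATATYPE_DICT[torch.float32]), i.e. a plain ONNX Cast node.
```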
48 changes: 26 additions & 22 deletions src/brevitas/export/onnx/standard/qcdq/handler.py
@@ -6,7 +6,7 @@
import torch

from brevitas.export.common.handler.qcdq import CDQMixin
from brevitas.export.common.handler.qcdq import DQMixin
from brevitas.export.common.handler.qcdq import DQCastMixin
from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin
@@ -17,16 +17,20 @@
from brevitas.export.onnx.handler import ONNXBaseHandler
from brevitas.export.onnx.handler import QuantLSTMLayerHandler

from ..function import CastFn
from ..function import DequantizeLinearFn
from ..function import IntClipFn
from ..function import QuantizeLinearFn


class StdDQONNXMixin(DQMixin, ABC):
class StdDQCastONNXMixin(DQCastMixin, ABC):

def dequantize_fn(self, x, scale, zero_point, axis):
return DequantizeLinearFn.apply(x, scale, zero_point, axis)

def cast_fn(self, x, dtype):
return CastFn.apply(x, dtype)

@property
def flatten_dequantize_params(self):
return True
@@ -39,13 +43,13 @@ def validate(self, module):
assert module.bit_width() > 1., 'Binary quant not supported'


class StdCDQONNXMixin(CDQMixin, StdDQONNXMixin, ABC):
class StdCDQCastONNXMixin(CDQMixin, StdDQCastONNXMixin, ABC):

def clip_fn(self, x, min_val, max_val):
return IntClipFn.apply(x, min_val, max_val)


class StdQCDQONNXMixin(QMixin, StdCDQONNXMixin, ABC):
class StdQCDQCastONNXMixin(QMixin, StdCDQCastONNXMixin, ABC):

@classmethod
def int8_dtype(cls):
@@ -70,42 +74,42 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis):
return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)


class StdQCDQONNXWeightQuantProxyHandler(StdCDQONNXMixin,
QCDQWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
class StdQCDQCastONNXWeightQuantProxyHandler(StdCDQCastONNXMixin,
QCDQWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQONNXDecoupledWeightQuantProxyHandler(StdCDQONNXMixin,
QCDQDecoupledWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
class StdQCDQCastONNXDecoupledWeightQuantProxyHandler(StdCDQCastONNXMixin,
QCDQDecoupledWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler(
StdCDQONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
class StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler(
StdCDQCastONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
pass


class StdQCDQONNXActQuantProxyHandler(StdQCDQONNXMixin,
QCDQActQuantProxyHandlerMixin,
ONNXBaseHandler):
class StdQCDQCastONNXActQuantProxyHandler(StdQCDQCastONNXMixin,
QCDQActQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQONNXBiasQuantProxyHandler(StdDQONNXMixin,
QCDQBiasQuantProxyHandlerMixin,
ONNXBaseHandler):
class StdQCDQCastONNXBiasQuantProxyHandler(StdDQCastONNXMixin,
QCDQBiasQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQONNXTruncQuantProxyHandler(StdQCDQONNXMixin,
QCDQTruncQuantProxyHandlerMixin,
ONNXBaseHandler):
class StdQCDQCastONNXTruncQuantProxyHandler(StdQCDQCastONNXMixin,
QCDQTruncQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQONNXQuantLSTMLayerHandler(QuantLSTMLayerHandler):
class StdQCDQCastONNXQuantLSTMLayerHandler(QuantLSTMLayerHandler):

def quantized_cell_symbolic_execution(
self,
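
Because DQCastMixin now declares cast_fn as abstract, the ONNX mixins in this file satisfy the contract through CastFn.apply; any other QCDQ export backend that reuses the common mixins has to provide its own implementation. A minimal sketch, with hypothetical names, of what that could look like:

```python
from abc import ABC

from brevitas.export.common.handler.qcdq import DQCastMixin


class MyBackendDQCastMixin(DQCastMixin, ABC):
    """Hypothetical backend mixin: only the new cast contract is sketched here;
    the remaining DQMixin abstract methods would still need to be implemented."""

    def cast_fn(self, x, dtype):
        # Simplest possible cast for a backend that does not need a custom symbolic op
        return x.to(dtype)
```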
28 changes: 14 additions & 14 deletions src/brevitas/export/onnx/standard/qcdq/manager.py
@@ -14,13 +14,13 @@
from ..function import IntClipFn
from ..function import QuantizeLinearFn
from ..manager import StdONNXBaseManager
from .handler import StdQCDQONNXActQuantProxyHandler
from .handler import StdQCDQONNXBiasQuantProxyHandler
from .handler import StdQCDQONNXDecoupledWeightQuantProxyHandler
from .handler import StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler
from .handler import StdQCDQONNXQuantLSTMLayerHandler
from .handler import StdQCDQONNXTruncQuantProxyHandler
from .handler import StdQCDQONNXWeightQuantProxyHandler
from .handler import StdQCDQCastONNXActQuantProxyHandler
from .handler import StdQCDQCastONNXBiasQuantProxyHandler
from .handler import StdQCDQCastONNXDecoupledWeightQuantProxyHandler
from .handler import StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler
from .handler import StdQCDQCastONNXQuantLSTMLayerHandler
from .handler import StdQCDQCastONNXTruncQuantProxyHandler
from .handler import StdQCDQCastONNXWeightQuantProxyHandler


class StdQCDQONNXManager(StdONNXBaseManager):
@@ -33,13 +33,13 @@ class StdQCDQONNXManager(StdONNXBaseManager):
"eliminate_unused_initializer"]

handlers = [
StdQCDQONNXWeightQuantProxyHandler,
StdQCDQONNXBiasQuantProxyHandler,
StdQCDQONNXActQuantProxyHandler,
StdQCDQONNXDecoupledWeightQuantProxyHandler,
StdQCDQONNXTruncQuantProxyHandler,
StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler,
StdQCDQONNXQuantLSTMLayerHandler]
StdQCDQCastONNXWeightQuantProxyHandler,
StdQCDQCastONNXBiasQuantProxyHandler,
StdQCDQCastONNXActQuantProxyHandler,
StdQCDQCastONNXDecoupledWeightQuantProxyHandler,
StdQCDQCastONNXTruncQuantProxyHandler,
StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler,
StdQCDQCastONNXQuantLSTMLayerHandler]

custom_fns = [
DebugMarkerFunction,
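
End to end, nothing changes for users of the manager: the same export entry point now also handles half-precision models by inserting the extra Cast nodes around Q/DQ. A hedged usage sketch; the layer, shapes, and the export_onnx_qcdq helper call are illustrative and assume they match your Brevitas version:

```python
import torch

import brevitas.nn as qnn
from brevitas.export import export_onnx_qcdq

# Illustrative model: a single quantized layer moved to float16.
model = qnn.QuantLinear(16, 8, bias=True).to(torch.float16).eval()

# The exporter casts float16 scales to float32 around QuantizeLinear/DequantizeLinear
# and restores the original dtype afterwards, so the exported graph stays float16 at
# its boundaries.
export_onnx_qcdq(
    model,
    args=torch.randn(1, 16, dtype=torch.float16),
    export_path='quant_linear_fp16.onnx')
```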