Feat: support for optimum #826

Merged: 20 commits, Feb 6, 2024

Changes from all commits
1 change: 1 addition & 0 deletions src/brevitas/config.py
@@ -22,5 +22,6 @@ def env_to_bool(name, default):
VERBOSE = env_to_bool('BREVITAS_VERBOSE', False)

# Internal global variables
_FULL_STATE_DICT = False
_IS_INSIDE_QUANT_LAYER = None
_ONGOING_EXPORT = None
27 changes: 27 additions & 0 deletions src/brevitas/core/function_wrapper/shape.py
@@ -14,6 +14,7 @@
from brevitas.function.shape import over_batch_over_output_channels
from brevitas.function.shape import over_batch_over_tensor
from brevitas.function.shape import over_output_channels
from brevitas.function.shape import over_output_features
from brevitas.function.shape import over_tensor


@@ -126,6 +127,32 @@ def forward(self, x: torch.Tensor):
return y.reshape(shape)


class OverOutputFeaturesView(brevitas.jit.ScriptModule):
"""
ScriptModule to compute the :func:`~brevitas.function.shape.over_output_features`
view of an input tensor.

Examples:
>>> view_module = OverOutputFeaturesView()
>>> y = view_module(torch.empty(size=[8, 10, 25]))
>>> y.shape
torch.Size([80, 25])
"""

def __init__(self, permute_dims: Optional[Tuple[int, ...]] = None) -> None:
super(OverOutputFeaturesView, self).__init__()
if permute_dims is not None:
self.permute_impl = PermuteDims(permute_dims)
else:
self.permute_impl = Identity()

@brevitas.jit.script_method
def forward(self, x: torch.Tensor):
y = self.permute_impl(x)
shape = over_output_features(y)
return y.reshape(shape)


class StatsInputViewShapeImpl(object):
"""
Enum-like object to collect pointers to variants of ScriptModules that perform a view on a tensor.
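Note: the new view above relies on brevitas.function.shape.over_output_features, which this diff imports but does not show. A minimal sketch of what it presumably computes, inferred from the docstring example (flatten every leading dim, keep the last one), not the actual Brevitas implementation:

import torch

def over_output_features(x: torch.Tensor):
    # Assumed behaviour: collapse all leading dims into one and keep the
    # last (feature) dim, so an [8, 10, 25] input is viewed as [80, 25].
    return (-1, x.shape[-1])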
7 changes: 5 additions & 2 deletions src/brevitas/core/stats/view_wrapper.py
@@ -7,6 +7,7 @@
from torch.nn import Parameter

import brevitas
import brevitas.config as config
from brevitas.core.function_wrapper import StatsInputViewShapeImpl # retrocomp


@@ -33,7 +34,8 @@ def _load_from_state_dict(
def state_dict(self, destination=None, prefix='', keep_vars=False):
output_dict = super(_ViewParameterWrapper, self).state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars)
del output_dict[prefix + 'parameter']
if not config._FULL_STATE_DICT:
del output_dict[prefix + 'parameter']
return output_dict


@@ -62,5 +64,6 @@ def _load_from_state_dict(
def state_dict(self, destination=None, prefix='', keep_vars=False):
output_dict = super(_ViewCatParameterWrapper, self).state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars)
del output_dict[prefix + 'parameter']
if not config._FULL_STATE_DICT:
del output_dict[prefix + 'parameter']
return output_dict
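
A minimal usage sketch of the new _FULL_STATE_DICT flag with a standard Brevitas layer (the model below is illustrative; the flag is simply toggled around state_dict() calls):

import torch.nn as nn

import brevitas.config as config
from brevitas.nn import QuantLinear

model = nn.Sequential(QuantLinear(16, 8, bias=False))

# Default behaviour: the view wrappers drop their tracked 'parameter' entries.
trimmed_sd = model.state_dict()

# With the flag set, those entries are kept, giving a fully materialised
# checkpoint, which downstream integrations such as optimum presumably expect.
config._FULL_STATE_DICT = True
full_sd = model.state_dict()
config._FULL_STATE_DICT = False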
64 changes: 55 additions & 9 deletions src/brevitas/export/common/handler/qcdq.py
@@ -133,15 +133,46 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
'scale_orig_shape': scale_orig_shape}


class CDQCastWeightQuantProxyHandlerMixin(CDQCastProxyHandlerMixin, ABC):
class QCDQCastWeightQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin):
handled_layer = WeightQuantProxyFromInjector
_export_q_node = False

def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
# compute axis before redefining scale
axis = cls.quant_axis(scale)
scale = to_0dim_if_scalar(scale.flatten())
zp = to_0dim_if_scalar(zero_point.flatten())
# expand_as must go after 0-dim check
zp = zp.expand_as(scale)
zp = cls.zero_point_with_dtype(is_signed, bit_width, zp)
if cls.itemize_quantize_scalar_params:
scale = to_item_if_0dim(scale)
zp = to_item_if_0dim(zp)
dtype = cls.signed_dtype(bit_width, is_signed)
return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis}

def prepare_quantize_from_floating_point(self, module):
quant_weight = module.tracked_module_list[0].quant_weight()
scale = quant_weight.scale
self.scale_dtype = scale.dtype
if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
scale = self.cast_fn(scale, torch.float32)
self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs(
scale, quant_weight.zero_point, quant_weight.bit_width, module.is_signed)

def prepare_quantize_from_integer(self, module):
int_weights = {
tm.weight.data_ptr(): tm.quant_weight().int(float_datatype=False)
for tm in module.tracked_module_list}
self.symbolic_kwargs['int_weights'] = int_weights

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.validate(module)
int_weights = {
tm.weight.data_ptr(): tm.quant_weight().int(float_datatype=False)
for tm in module.tracked_module_list}
if self._export_q_node:
self.prepare_quantize_from_floating_point(module)
else:
self.prepare_quantize_from_integer(module)
# Get the first quant weight as representative
quant_weight = module.tracked_module_list[0].quant_weight()

@@ -153,7 +184,6 @@ def prepare_for_export(self, module):
if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
scale = self.cast_fn(scale, torch.float32)

self.symbolic_kwargs['int_weights'] = int_weights
self.symbolic_kwargs['bit_width'] = quant_weight.bit_width
self.symbolic_kwargs['clip_symbolic_kwargs'] = self.int_clip_symbolic_kwargs(
module.is_narrow_range, module.is_signed, quant_weight.bit_width)
@@ -162,9 +192,20 @@ def symbolic_execution(self, x: Tensor):
else:
self.symbolic_kwargs = None

def quantize_from_floating_point(self, x: Tensor):
quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs']
x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
return x

def quantize_from_integer(self, x: Tensor):
return self.symbolic_kwargs['int_weights'][x.data_ptr()]

def symbolic_execution(self, x: Tensor):
assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled'
x = self.symbolic_kwargs['int_weights'][x.data_ptr()]
if self._export_q_node:
x = self.quantize_from_floating_point(x)
else:
x = self.quantize_from_integer(x)
clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs']
# Copy dict to allow for popping kwargs even on shared quantizers
dequantize_symbolic_kwargs = copy(self.symbolic_kwargs['dequantize_symbolic_kwargs'])
@@ -187,18 +228,23 @@ def symbolic_execution(self, x: Tensor):
return x, scale, zero_point, bit_width


class CDQCastDecoupledWeightQuantProxyHandlerMixin(CDQCastWeightQuantProxyHandlerMixin, ABC):
class QCDQCastDecoupledWeightQuantProxyHandlerMixin(QCDQCastWeightQuantProxyHandlerMixin, ABC):
handled_layer = DecoupledWeightQuantProxyFromInjector
_export_q_node = False

def symbolic_execution(self, x: Tensor):
out, scale, zero_point, bit_width = super().symbolic_execution(x)
# Return post-rounding scale and zero-point in place of pre-rounding as a placeholder
return out, scale, zero_point, scale, zero_point, bit_width


class CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin(
CDQCastDecoupledWeightQuantProxyHandlerMixin, ABC):
class QCDQCastDecoupledWeightQuantWithInputProxyHandlerMixin(
QCDQCastDecoupledWeightQuantProxyHandlerMixin, ABC):
handled_layer = DecoupledWeightQuantWithInputProxyFromInjector
_export_q_node = False

def validate(self, module):
assert not self._export_q_node, "This proxy requires to export integer weights"

def symbolic_execution(self, x: Tensor, input_bit_width: torch.Tensor, input_is_signed: bool):
return super().symbolic_execution(x)
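In short, the renamed weight handler now supports two export styles, selected per handler by _export_q_node: either emit a full QuantizeLinear -> (Clip) -> DequantizeLinear chain fed by the floating-point weight, or (the default) bake the pre-computed integer weight into the graph and emit only Clip/DequantizeLinear. A schematic of the dispatch, mirroring symbolic_execution above rather than adding new behaviour:

# Schematic only, not part of the diff.
def export_weight(handler, x):
    if handler._export_q_node:
        # QuantizeLinear appears in the exported graph, fed by the float weight.
        return handler.quantize_from_floating_point(x)
    # Integer weight is embedded as an initializer; only Clip/DQ follow.
    return handler.quantize_from_integer(x)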
27 changes: 14 additions & 13 deletions src/brevitas/export/onnx/standard/qcdq/handler.py
@@ -6,16 +6,16 @@
import torch

from brevitas.export.common.handler.qcdq import CDQCastBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQCastDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import \
CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQCastMixin
from brevitas.export.common.handler.qcdq import CDQCastWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DQCastMixin
from brevitas.export.common.handler.qcdq import DynamicQDQCastActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DynamicQMixin
from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import \
QCDQCastDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QMixin
from brevitas.export.onnx.handler import ONNXBaseHandler
from brevitas.export.onnx.handler import QuantLSTMLayerHandler
@@ -71,7 +71,8 @@ def validate(self, module):
super().validate(module)
# ONNX QuantizeLinear supports only 8b output with round to nearest even.
# Below 8b quantization is supported through clipping.
assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported'
if getattr(self, '_export_q_node', True):
assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported'
self.validate_8b_bit_width(module.bit_width(), le_then=True)

def quantize_fn(self, x, scale, zero_point, dtype, axis):
@@ -109,20 +110,20 @@ def quantize_fn(self, x, dtype):
return DynamicQuantizeLinearFn.apply(x, dtype)


class StdCDQCastONNXWeightQuantProxyHandler(StdCDQCastONNXMixin,
CDQCastWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
class StdQCDQCastONNXWeightQuantProxyHandler(StdQCDQCastONNXMixin,
QCDQCastWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdCDQCastONNXDecoupledWeightQuantProxyHandler(StdCDQCastONNXMixin,
CDQCastDecoupledWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
class StdQCDQCastONNXDecoupledWeightQuantProxyHandler(StdQCDQCastONNXMixin,
QCDQCastDecoupledWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdCDQCastONNXDecoupledWeightQuantWithInputProxyHandler(
StdCDQCastONNXMixin, CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin,
class StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler(
StdQCDQCastONNXMixin, QCDQCastDecoupledWeightQuantWithInputProxyHandlerMixin,
ONNXBaseHandler):
pass

23 changes: 17 additions & 6 deletions src/brevitas/export/onnx/standard/qcdq/manager.py
@@ -16,13 +16,13 @@
from ..function import QuantizeLinearFn
from ..manager import StdONNXBaseManager
from .handler import StdCDQCastONNXBiasQuantProxyHandler
from .handler import StdCDQCastONNXDecoupledWeightQuantProxyHandler
from .handler import StdCDQCastONNXDecoupledWeightQuantWithInputProxyHandler
from .handler import StdCDQCastONNXWeightQuantProxyHandler
from .handler import StdDynamicQDQCastONNXActQuantProxyHandler
from .handler import StdQCDQCastONNXActQuantProxyHandler
from .handler import StdQCDQCastONNXDecoupledWeightQuantProxyHandler
from .handler import StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler
from .handler import StdQCDQCastONNXQuantLSTMLayerHandler
from .handler import StdQCDQCastONNXTruncQuantProxyHandler
from .handler import StdQCDQCastONNXWeightQuantProxyHandler


class StdQCDQONNXManager(StdONNXBaseManager):
@@ -35,13 +35,13 @@ class StdQCDQONNXManager(StdONNXBaseManager):
"eliminate_unused_initializer"]

handlers = [
StdCDQCastONNXWeightQuantProxyHandler,
StdQCDQCastONNXWeightQuantProxyHandler,
StdCDQCastONNXBiasQuantProxyHandler,
StdQCDQCastONNXActQuantProxyHandler,
StdCDQCastONNXDecoupledWeightQuantProxyHandler,
StdQCDQCastONNXDecoupledWeightQuantProxyHandler,
StdDynamicQDQCastONNXActQuantProxyHandler,
StdQCDQCastONNXTruncQuantProxyHandler,
StdCDQCastONNXDecoupledWeightQuantWithInputProxyHandler,
StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler,
StdQCDQCastONNXQuantLSTMLayerHandler]

custom_fns = [
@@ -61,3 +61,14 @@ def set_export_mode(cls, model: Module, enabled: bool):
def set_export_handler(cls, module: Module):
_set_proxy_export_handler(cls, module)
_set_recurrent_layer_export_handler(cls, module)

@classmethod
def export_onnx(cls, *args, export_weight_q_node: bool = False, **kwargs):
cls.change_weight_export(export_weight_q_node)
return super().export_onnx(*args, **kwargs)

@classmethod
def change_weight_export(cls, export_weight_q_node: bool = False):
for handler in cls.handlers:
if hasattr(handler, '_export_q_node'):
handler._export_q_node = export_weight_q_node
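
A hedged usage sketch of the new export knob (model, input, and path are illustrative; the manager is assumed to forward the usual args/export_path arguments to its base export_onnx):

import torch

from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
from brevitas.nn import QuantLinear

model = QuantLinear(16, 8, bias=False)

# export_weight_q_node=True keeps QuantizeLinear nodes on the weights in the
# exported graph; the default (False) bakes in pre-computed integer weights.
StdQCDQONNXManager.export_onnx(
    model, args=torch.randn(1, 16), export_path='quant_linear.onnx', export_weight_q_node=True)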
28 changes: 15 additions & 13 deletions src/brevitas/export/torch/qcdq/handler.py
@@ -7,14 +7,14 @@

from brevitas.export.common.handler.base import BaseHandler
from brevitas.export.common.handler.qcdq import CDQCastBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQCastDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import \
CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQCastMixin
from brevitas.export.common.handler.qcdq import CDQCastWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DQCastMixin
from brevitas.export.common.handler.qcdq import QCDQCastActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import \
QCDQCastDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastTruncQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQCastWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QMixin


@@ -79,7 +79,8 @@ def int32_dtype(cls):

def validate(self, module):
super().validate(module)
assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported'
if getattr(self, '_export_q_node', True):
assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported'

def quantize_fn(self, x, scale, zero_point, dtype, axis):
if axis is None:
@@ -95,28 +96,29 @@ def forward(self, *args, **kwargs):
return self.symbolic_execution(*args, **kwargs)


class TorchCDQCastWeightQuantProxyHandler(TorchCDQCastMixin,
CDQCastWeightQuantProxyHandlerMixin,
TorchQCDQHandler):
class TorchQCDQCastWeightQuantProxyHandler(TorchQCDQCastMixin,
QCDQCastWeightQuantProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
clip_args = super().int_clip_symbolic_kwargs(narrow, signed, bit_width)
return _itemize_clip_bounds(clip_args)


class TorchCDQCastDecoupledWeightQuantProxyHandler(TorchCDQCastMixin,
CDQCastDecoupledWeightQuantProxyHandlerMixin,
TorchQCDQHandler):
class TorchQCDQCastDecoupledWeightQuantProxyHandler(TorchQCDQCastMixin,
QCDQCastDecoupledWeightQuantProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
clip_args = super().int_clip_symbolic_kwargs(narrow, signed, bit_width)
return _itemize_clip_bounds(clip_args)


class TorchCDQCastDecoupledWeightQuantWithInputProxyHandler(
TorchCDQCastMixin, CDQCastDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler):
class TorchQCDQCastDecoupledWeightQuantWithInputProxyHandler(
TorchQCDQCastMixin, QCDQCastDecoupledWeightQuantWithInputProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):