Skip to content

Commit

Permalink
review
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Jan 18, 2024
1 parent 0fcf1e6 commit 0947d85
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 11 deletions.
4 changes: 2 additions & 2 deletions src/brevitas/export/onnx/standard/qcdq/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def int32_dtype(cls):

def validate(self, module):
super().validate(module)
# ONNX QuantizeLinear supports only 8b output with round to nearest even.
# Below 8b quantization is supported through clipping.
# ONNX DynamicQuantizeLinear supports only 8b output with round to nearest even.
assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported'
# Below 8b quantization is not supported.
self.validate_8b_bit_width(module.bit_width(), le_then=False)

def quantize_fn(self, x, dtype):
Expand Down
8 changes: 6 additions & 2 deletions src/brevitas_examples/common/generative/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
from brevitas_examples.common.generative.quantizers import Int8ActPerRowFloat
from brevitas_examples.common.generative.quantizers import Int8ActPerRowFloatMSE
from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant
from brevitas_examples.common.generative.quantizers import ShiftedUint8ActDynamicPerGroupFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8ActDynamicPerRowFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8ActDynamicPerTensorFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8ActPerRowFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8ActPerRowFloatMSE
Expand Down Expand Up @@ -112,9 +114,11 @@
'sym': Int8ActDynamicPerTensorFloat,
'asym': ShiftedUint8ActDynamicPerTensorFloat},
'per_row': {
'sym': Int8ActDynamicPerRowFloat},
'sym': Int8ActDynamicPerRowFloat,
'asym': ShiftedUint8ActDynamicPerRowFloat},
'per_group': {
'sym': Int8ActDynamicPerGroupFloat},}}}},
'sym': Int8ActDynamicPerGroupFloat,
'asym': ShiftedUint8ActDynamicPerGroupFloat},}}}},
'float': {
'static': {
'float_scale': {
Expand Down
41 changes: 34 additions & 7 deletions src/brevitas_examples/common/generative/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def reshaped_scaling_shape(module):
block_size = None


class ActDynamicProxyMixin(ExtendedInjector):
proxy_class = DynamicActQuantProxyFromInjector


class IntWeightSymmetricGroupQuant(WeightSymmetricGroupQuantMixin, Int8WeightPerChannelFloat):
"""
Block / group / vector signed symmetric int weight quantizer with float scales.
Expand Down Expand Up @@ -130,11 +134,32 @@ class Int8ActDynamicPerTensorFloat(Int8ActPerTensorFloat):
scaling_stats_op = 'max'


class ShiftedUint8ActDynamicPerTensorFloat(ShiftedUint8ActPerTensorFloat):
class Int8ActDynamicPerRowFloat(Int8ActPerRowFloat, ActDynamicProxyMixin):
"""
Symmetric quantizer with per tensor dynamic scale.
Symmetric quantizer with per row dynamic scale.
"""
scaling_impl = RuntimeDynamicStatsScaling
scaling_stats_input_view_shape_impl = OverBatchOverOutputChannelView
scaling_stats_op = 'max'


class Int8ActDynamicPerGroupFloat(Int8ActPerRowFloat, ActDynamicProxyMixin):
"""
Symmetric quantizer with per group scale.
"""
scaling_impl = RuntimeDynamicGroupStatsScaling
keepdim = True
scaling_stats_op = 'max'

@value
def stats_reduce_dim(group_dim):
return group_dim + 1


class ShiftedUint8ActDynamicPerTensorFloat(ShiftedUint8ActPerTensorFloat, ActDynamicProxyMixin):
"""
Asymmetric quantizer with per tensor dynamic scale.
"""
proxy_class = DynamicActQuantProxyFromInjector
scaling_impl = RuntimeDynamicStatsScaling
scaling_stats_input_view_shape_impl = OverTensorView
scaling_stats_op = 'max'
Expand All @@ -143,22 +168,24 @@ class ShiftedUint8ActDynamicPerTensorFloat(ShiftedUint8ActPerTensorFloat):
stats_reduce_dim = 0


class Int8ActDynamicPerRowFloat(Int8ActPerRowFloat):
class ShiftedUint8ActDynamicPerRowFloat(ShiftedUint8ActPerRowFloat, ActDynamicProxyMixin):
"""
Symmetric quantizer with per row dynamic scale.
Asymmetric quantizer with per row dynamic scale.
"""
scaling_impl = RuntimeDynamicStatsScaling
scaling_stats_input_view_shape_impl = OverBatchOverOutputChannelView
scaling_stats_op = 'max'
zero_point_stats_impl = NegativeMinOrZero


class Int8ActDynamicPerGroupFloat(Int8ActPerRowFloat):
class ShiftedUint8ActDynamicPerGroupFloat(ShiftedUint8ActPerRowFloat, ActDynamicProxyMixin):
"""
Symmetric quantizer with per group scale.
Asymmetric quantizer with per group dynamic scale.
"""
scaling_impl = RuntimeDynamicGroupStatsScaling
keepdim = True
scaling_stats_op = 'max'
zero_point_stats_impl = NegativeMinOrZero

@value
def stats_reduce_dim(group_dim):
Expand Down

0 comments on commit 0947d85

Please sign in to comment.