Skip to content

Commit

Permalink
fix some errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Sep 23, 2024
1 parent 25dd388 commit eb6e108
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions src/brevitas/quant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,19 @@ class PerChannelPreNorm(ExtendedInjector):

class AccumulatorAwarePerChannelPreNorm(PerChannelPreNorm):

@value
def accumulator_bit_width_impl(accumulator_bit_width):
return BitWidthStatefulConst(accumulator_bit_width)

pre_scaling_impl = AccumulatorAwareParameterPreScaling
accumulator_bit_width = 32 # default maximum accumulator width is 32 bits
accumulator_bit_width = (this << 1).accumulator_bit_width
accumulator_bit_width_impl = (this << 1).accumulator_bit_width_impl


class AccumulatorAwareZeroCenterPerChannelPreNorm(AccumulatorAwarePerChannelPreNorm):

pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling
pre_zero_point_impl = PreZeroCenterZeroPoint
pre_zero_point_shape = this.scaling_shape # TODO: decouple zero_point from scaling
pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
stats_reduce_dim = (this << 1).stats_reduce_dim
scaling_shape = (this << 1).scaling_shape


class SolvePostScaleGranularity(ExtendedInjector):
Expand Down Expand Up @@ -457,10 +464,12 @@ class AccumulatorAwareWeightQuant(WeightNormPerChannelFloatDecoupled):
per_channel_pre_norm = AccumulatorAwarePerChannelPreNorm
normalize_stats_impl = PerChannelL1Norm.normalize_stats_impl # required to align with derivations in paper
float_to_int_impl = RoundToZeroSte # required to ensure no upwards rounding violates constraints
accumulator_bit_width = 32 # default maximum accumulator width is 32 bits

@value
def accumulator_bit_width():
return this.per_channel_pre_norm.accumulator_bit_width
def accumulator_bit_width_impl(accumulator_bit_width):
return BitWidthStatefulConst(accumulator_bit_width)


class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
"""Experimental zero-centered accumulator-aware weight quantizer based on:
Expand All @@ -470,10 +479,7 @@ class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
(1) added zero-centering constraint on the weights (i.e., `PreZeroCenterZeroPoint`)
(2) a more relaxed l1-norm bound that is derived in the referenced paper
"""
pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling
pre_zero_point_impl = PreZeroCenterZeroPoint
pre_zero_point_shape = this.scaling_shape # TODO: decouple zero_point from scaling
pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
per_channel_pre_norm = AccumulatorAwareZeroCenterPerChannelPreNorm


class MSESubInjectorBase(ExtendedInjector):
Expand Down

0 comments on commit eb6e108

Please sign in to comment.