Skip to content

Commit eb6e108

Browse files
committed
fix some errors
1 parent 25dd388 commit eb6e108

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

src/brevitas/quant/base.py

+17-11
Original file line numberDiff line numberDiff line change
@@ -357,12 +357,19 @@ class PerChannelPreNorm(ExtendedInjector):
357357

358358
class AccumulatorAwarePerChannelPreNorm(PerChannelPreNorm):
359359

360-
@value
361-
def accumulator_bit_width_impl(accumulator_bit_width):
362-
return BitWidthStatefulConst(accumulator_bit_width)
363-
364360
pre_scaling_impl = AccumulatorAwareParameterPreScaling
365-
accumulator_bit_width = 32 # default maximum accumulator width is 32 bits
361+
accumulator_bit_width = (this << 1).accumulator_bit_width
362+
accumulator_bit_width_impl = (this << 1).accumulator_bit_width_impl
363+
364+
365+
class AccumulatorAwareZeroCenterPerChannelPreNorm(AccumulatorAwarePerChannelPreNorm):
366+
367+
pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling
368+
pre_zero_point_impl = PreZeroCenterZeroPoint
369+
pre_zero_point_shape = this.scaling_shape # TODO: decouple zero_point from scaling
370+
pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
371+
stats_reduce_dim = (this << 1).stats_reduce_dim
372+
scaling_shape = (this << 1).scaling_shape
366373

367374

368375
class SolvePostScaleGranularity(ExtendedInjector):
@@ -457,10 +464,12 @@ class AccumulatorAwareWeightQuant(WeightNormPerChannelFloatDecoupled):
457464
per_channel_pre_norm = AccumulatorAwarePerChannelPreNorm
458465
normalize_stats_impl = PerChannelL1Norm.normalize_stats_impl # required to align with derivations in paper
459466
float_to_int_impl = RoundToZeroSte # required to ensure no upwards rounding violates constraints
467+
accumulator_bit_width = 32 # default maximum accumulator width is 32 bits
460468

461469
@value
462-
def accumulator_bit_width():
463-
return this.per_channel_pre_norm.accumulator_bit_width
470+
def accumulator_bit_width_impl(accumulator_bit_width):
471+
return BitWidthStatefulConst(accumulator_bit_width)
472+
464473

465474
class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
466475
"""Experimental zero-centered accumulator-aware weight quantized based on:
@@ -470,10 +479,7 @@ class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
470479
(1) added zero-centering constraint on the weights (i.e., `PreZeroCenterZeroPoint`)
471480
(2) a more relaxed l1-norm bound that is derived in the referenced paper
472481
"""
473-
pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling
474-
pre_zero_point_impl = PreZeroCenterZeroPoint
475-
pre_zero_point_shape = this.scaling_shape # TODO: decouple zero_point from scaling
476-
pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
482+
per_channel_pre_norm = AccumulatorAwareZeroCenterPerChannelPreNorm
477483

478484

479485
class MSESubInjectorBase(ExtendedInjector):

0 commit comments

Comments
 (0)