@@ -357,12 +357,19 @@ class PerChannelPreNorm(ExtendedInjector):
357
357
358
358
class AccumulatorAwarePerChannelPreNorm (PerChannelPreNorm ):
359
359
360
- @value
361
- def accumulator_bit_width_impl (accumulator_bit_width ):
362
- return BitWidthStatefulConst (accumulator_bit_width )
363
-
364
360
pre_scaling_impl = AccumulatorAwareParameterPreScaling
365
- accumulator_bit_width = 32 # default maximum accumulator width is 32 bits
361
+ accumulator_bit_width = (this << 1 ).accumulator_bit_width
362
+ accumulator_bit_width_impl = (this << 1 ).accumulator_bit_width_impl
363
+
364
+
365
+ class AccumulatorAwareZeroCenterPerChannelPreNorm (AccumulatorAwarePerChannelPreNorm ):
366
+
367
+ pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling
368
+ pre_zero_point_impl = PreZeroCenterZeroPoint
369
+ pre_zero_point_shape = this .scaling_shape # TODO: decouple zero_point from scaling
370
+ pre_zero_point_stats_input_view_shape_impl = this .scaling_stats_input_view_shape_impl
371
+ stats_reduce_dim = (this << 1 ).stats_reduce_dim
372
+ scaling_shape = (this << 1 ).scaling_shape
366
373
367
374
368
375
class SolvePostScaleGranularity (ExtendedInjector ):
@@ -457,10 +464,12 @@ class AccumulatorAwareWeightQuant(WeightNormPerChannelFloatDecoupled):
457
464
per_channel_pre_norm = AccumulatorAwarePerChannelPreNorm
458
465
normalize_stats_impl = PerChannelL1Norm .normalize_stats_impl # required to align with derivations in paper
459
466
float_to_int_impl = RoundToZeroSte # required to ensure no upwards rounding violates constraints
467
+ accumulator_bit_width = 32 # default maximum accumulator width is 32 bits
460
468
461
469
@value
462
- def accumulator_bit_width ():
463
- return this .per_channel_pre_norm .accumulator_bit_width
470
+ def accumulator_bit_width_impl (accumulator_bit_width ):
471
+ return BitWidthStatefulConst (accumulator_bit_width )
472
+
464
473
465
474
class AccumulatorAwareZeroCenterWeightQuant (AccumulatorAwareWeightQuant ):
466
475
"""Experimental zero-centered accumulator-aware weight quantized based on:
@@ -470,10 +479,7 @@ class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
470
479
(1) added zero-centering constraint on the weights (i.e., `PreZeroCenterZeroPoint`)
471
480
(2) a more relaxed l1-norm bound that is derived in the referenced paper
472
481
"""
473
- pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling
474
- pre_zero_point_impl = PreZeroCenterZeroPoint
475
- pre_zero_point_shape = this .scaling_shape # TODO: decouple zero_point from scaling
476
- pre_zero_point_stats_input_view_shape_impl = this .scaling_stats_input_view_shape_impl
482
+ per_channel_pre_norm = AccumulatorAwareZeroCenterPerChannelPreNorm
477
483
478
484
479
485
class MSESubInjectorBase (ExtendedInjector ):
0 commit comments