diff --git a/src/brevitas/core/scaling/pre_scaling.py b/src/brevitas/core/scaling/pre_scaling.py
index 396b747fa..632242507 100644
--- a/src/brevitas/core/scaling/pre_scaling.py
+++ b/src/brevitas/core/scaling/pre_scaling.py
@@ -15,8 +15,7 @@
 from brevitas.core.stats.stats_wrapper import _Stats
 from brevitas.function import abs_binary_sign_grad
 
-__all__ = [
-    "ParameterPreScalingWeightNorm",]
+__all__ = ["ParameterPreScalingWeightNorm", "AccumulatorAwareParameterPreScaling"]
 
 
 class ParameterPreScalingWeightNorm(brevitas.jit.ScriptModule):
diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py
index 3c82a3d5e..f72045dab 100644
--- a/src/brevitas/quant/base.py
+++ b/src/brevitas/quant/base.py
@@ -352,7 +352,7 @@ def scaling_init(scaling_init_impl, bit_width):
     tensor_clamp_impl = TensorClamp
     scaling_impl = ParameterScaling
     scaling_init_impl = StatsFromParameterScaling
-    restrict_scaling_impl = FloatRestrictValue
+    restrict_scaling_impl = LogFloatRestrictValue
     scaling_stats_impl = AbsMax
     pre_scaling_impl = ParameterPreScalingWeightNorm
     restrict_pre_scaling_impl = LogFloatRestrictValue
@@ -395,7 +395,6 @@ def accumulator_bit_width_impl(accumulator_bit_width):
     proxy_class = DecoupledWeightQuantWithInputProxyFromInjector
     tensor_quant = DecoupledRescalingIntQuantWithInput
     pre_scaling_impl = AccumulatorAwareParameterPreScaling
-    pre_scaling_min_val = 1e-8
     accumulator_bit_width = 32  # default maximum accumulator width is 32 bits
     normalize_stats_impl = L1Norm  # required to align with derivations in paper
     float_to_int_impl = RoundToZeroSte  # required to ensure no upwards rounding violates constraints
diff --git a/src/brevitas_examples/super_resolution/models/common.py b/src/brevitas_examples/super_resolution/models/common.py
index d3022d089..16ba143c5 100644
--- a/src/brevitas_examples/super_resolution/models/common.py
+++ b/src/brevitas_examples/super_resolution/models/common.py
@@ -6,6 +6,7 @@
 from torch import Tensor
 import torch.nn as nn
 
+from brevitas.core.restrict_val import FloatRestrictValue
 from brevitas.core.restrict_val import RestrictValueType
 from brevitas.core.scaling import ScalingImplType
 import brevitas.nn as qnn
@@ -25,6 +26,7 @@ class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat):
 
 
 class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant):
+    restrict_scaling_impl = FloatRestrictValue  # backwards compatibility
     pre_scaling_min_val = 1e-10
     scaling_min_val = 1e-10
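
Note for downstream users: this patch moves the default restrict_scaling_impl of the accumulator-aware (A2Q) weight quantizer base from FloatRestrictValue to LogFloatRestrictValue, which is why the super-resolution example above pins FloatRestrictValue back for backwards compatibility. The sketch below shows how other downstream code could do the same; the subclass name and layer parameters are illustrative only, and it assumes the usual A2Q setup where the layer input is quantized so the weight quantizer can read the input bit width.

# Illustrative sketch, not part of this patch: restore the pre-change
# float-domain scale restriction by overriding restrict_scaling_impl,
# mirroring the common.py change above.
from brevitas.core.restrict_val import FloatRestrictValue
from brevitas.quant import Int8AccumulatorAwareWeightQuant
from brevitas.quant import Int8ActPerTensorFloat
import brevitas.nn as qnn


class FloatScaleA2QWeightQuant(Int8AccumulatorAwareWeightQuant):
    # The new default in quant/base.py is LogFloatRestrictValue; override
    # it here to recover the old float-domain behavior.
    restrict_scaling_impl = FloatRestrictValue


# A2Q weight quantization consumes the input quantizer's bit width, so the
# layer needs a quantized input; weight_accumulator_bit_width is forwarded
# to the weight quantizer and caps the accumulator width (default 32).
conv = qnn.QuantConv2d(
    in_channels=3,
    out_channels=16,
    kernel_size=3,
    input_quant=Int8ActPerTensorFloat,
    weight_quant=FloatScaleA2QWeightQuant,
    weight_accumulator_bit_width=16)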