From 216e68ee21c51c9b06db15158f6e8d6a2f6727a7 Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 10 Oct 2023 14:45:07 -0700 Subject: [PATCH 1/4] Update A2Q defaults --- src/brevitas/quant/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index 3c82a3d5e..f72045dab 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -352,7 +352,7 @@ def scaling_init(scaling_init_impl, bit_width): tensor_clamp_impl = TensorClamp scaling_impl = ParameterScaling scaling_init_impl = StatsFromParameterScaling - restrict_scaling_impl = FloatRestrictValue + restrict_scaling_impl = LogFloatRestrictValue scaling_stats_impl = AbsMax pre_scaling_impl = ParameterPreScalingWeightNorm restrict_pre_scaling_impl = LogFloatRestrictValue @@ -395,7 +395,6 @@ def accumulator_bit_width_impl(accumulator_bit_width): proxy_class = DecoupledWeightQuantWithInputProxyFromInjector tensor_quant = DecoupledRescalingIntQuantWithInput pre_scaling_impl = AccumulatorAwareParameterPreScaling - pre_scaling_min_val = 1e-8 accumulator_bit_width = 32 # default maximum accumulator width is 32 bits normalize_stats_impl = L1Norm # required to align with derivations in paper float_to_int_impl = RoundToZeroSte # required to ensure no upwards rounding violates constraints From 78ec9be3bf3cf55d8bf186a661ed4aa5a6850ca5 Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 10 Oct 2023 14:45:28 -0700 Subject: [PATCH 2/4] Fixing backwards compatibility --- src/brevitas_examples/super_resolution/models/common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/brevitas_examples/super_resolution/models/common.py b/src/brevitas_examples/super_resolution/models/common.py index d3022d089..124ba93e9 100644 --- a/src/brevitas_examples/super_resolution/models/common.py +++ b/src/brevitas_examples/super_resolution/models/common.py @@ -7,6 +7,7 @@ import torch.nn as nn from brevitas.core.restrict_val import RestrictValueType +from brevitas.core.restrict_val import FloatRestrictValue from brevitas.core.scaling import ScalingImplType import brevitas.nn as qnn from brevitas.nn.quant_layer import WeightQuantType @@ -25,6 +26,7 @@ class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat): class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant): + restrict_scaling_impl = FloatRestrictValue # backwards compatibility pre_scaling_min_val = 1e-10 scaling_min_val = 1e-10 From 6cdadcd5329186888d069b7fd87f486ed79c8e06 Mon Sep 17 00:00:00 2001 From: icolbert Date: Thu, 12 Oct 2023 07:52:19 -0700 Subject: [PATCH 3/4] Pre-commit fixes --- src/brevitas_examples/super_resolution/models/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/super_resolution/models/common.py b/src/brevitas_examples/super_resolution/models/common.py index 124ba93e9..16ba143c5 100644 --- a/src/brevitas_examples/super_resolution/models/common.py +++ b/src/brevitas_examples/super_resolution/models/common.py @@ -6,8 +6,8 @@ from torch import Tensor import torch.nn as nn -from brevitas.core.restrict_val import RestrictValueType from brevitas.core.restrict_val import FloatRestrictValue +from brevitas.core.restrict_val import RestrictValueType from brevitas.core.scaling import ScalingImplType import brevitas.nn as qnn from brevitas.nn.quant_layer import WeightQuantType From b5f0e9eada70e3da677bf2660d2d38bf5709d079 Mon Sep 17 00:00:00 2001 From: icolbert Date: Thu, 2 Nov 2023 11:06:25 -0700 Subject: [PATCH 4/4] Update pre_scaling.py --- src/brevitas/core/scaling/pre_scaling.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/brevitas/core/scaling/pre_scaling.py b/src/brevitas/core/scaling/pre_scaling.py index 396b747fa..632242507 100644 --- a/src/brevitas/core/scaling/pre_scaling.py +++ b/src/brevitas/core/scaling/pre_scaling.py @@ -15,8 +15,7 @@ from brevitas.core.stats.stats_wrapper import _Stats from brevitas.function import abs_binary_sign_grad -__all__ = [ - "ParameterPreScalingWeightNorm",] +__all__ = ["ParameterPreScalingWeightNorm", "AccumulatorAwareParameterPreScaling"] class ParameterPreScalingWeightNorm(brevitas.jit.ScriptModule):