Skip to content

Commit

Permalink
Fix (base): Updating A2Q defaults (#718)
Browse files — browse the repository at this point in the history
  • Loading branch information
i-colbert authored Nov 6, 2023
1 parent e410ff3 commit 4d51f18
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
3 changes: 1 addition & 2 deletions src/brevitas/core/scaling/pre_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from brevitas.core.stats.stats_wrapper import _Stats
from brevitas.function import abs_binary_sign_grad

__all__ = [
"ParameterPreScalingWeightNorm",]
__all__ = ["ParameterPreScalingWeightNorm", "AccumulatorAwareParameterPreScaling"]


class ParameterPreScalingWeightNorm(brevitas.jit.ScriptModule):
Expand Down
3 changes: 1 addition & 2 deletions src/brevitas/quant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def scaling_init(scaling_init_impl, bit_width):
tensor_clamp_impl = TensorClamp
scaling_impl = ParameterScaling
scaling_init_impl = StatsFromParameterScaling
restrict_scaling_impl = FloatRestrictValue
restrict_scaling_impl = LogFloatRestrictValue
scaling_stats_impl = AbsMax
pre_scaling_impl = ParameterPreScalingWeightNorm
restrict_pre_scaling_impl = LogFloatRestrictValue
Expand Down Expand Up @@ -395,7 +395,6 @@ def accumulator_bit_width_impl(accumulator_bit_width):
proxy_class = DecoupledWeightQuantWithInputProxyFromInjector
tensor_quant = DecoupledRescalingIntQuantWithInput
pre_scaling_impl = AccumulatorAwareParameterPreScaling
pre_scaling_min_val = 1e-8
accumulator_bit_width = 32 # default maximum accumulator width is 32 bits
normalize_stats_impl = L1Norm # required to align with derivations in paper
float_to_int_impl = RoundToZeroSte # required to ensure no upwards rounding violates constraints
Expand Down
2 changes: 2 additions & 0 deletions src/brevitas_examples/super_resolution/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch import Tensor
import torch.nn as nn

from brevitas.core.restrict_val import FloatRestrictValue
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType
import brevitas.nn as qnn
Expand All @@ -25,6 +26,7 @@ class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat):


class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant):
restrict_scaling_impl = FloatRestrictValue # backwards compatibility
pre_scaling_min_val = 1e-10
scaling_min_val = 1e-10

Expand Down

0 comments on commit 4d51f18

Please sign in to comment.