diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py
index 5126e4f83..6a14f86cb 100644
--- a/src/brevitas/core/scaling/standalone.py
+++ b/src/brevitas/core/scaling/standalone.py
@@ -80,8 +80,11 @@ def __init__(
     def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(placeholder)
-        value = self.value() / threshold
-        restricted_value = self.restrict_clamp_scaling(value)
+        # We first apply any restriction to scaling
+        # For IntQuant, this is no-op, retrocompatible.
+        threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold))
+        restricted_value = self.restrict_clamp_scaling(self.value())
+        restricted_value = restricted_value / threshold
         return restricted_value
 
@@ -145,6 +148,7 @@ def __init__(
             scaling_init = torch.tensor(scaling_init, dtype=dtype, device=device)
         if restrict_scaling_impl is not None:
             scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
+            self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
         if scaling_init.shape == SCALAR_SHAPE and scaling_shape is not None:
             scaling_init = torch.full(scaling_shape, scaling_init, dtype=dtype, device=device)
         self.value = Parameter(scaling_init)
@@ -154,8 +158,11 @@ def __init__(
     def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(placeholder)
-        value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value / threshold))
-        return value
+        # We first apply any restriction to scaling
+        # For IntQuant, this is no-op, retrocompatible.
+        threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold))
+        value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
+        return value / threshold
 
     def _load_from_state_dict(
             self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
@@ -363,7 +370,7 @@ def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tens
             self.restrict_inplace_preprocess(self.buffer)
             inplace_tensor_mul(self.value.detach(), self.buffer)
             threshold = self.restrict_preprocess(threshold)
-            value = self.restrict_scaling_impl.combine_stats_threshold(value, threshold)
+            value = self.restrict_scaling_impl.combine_stats_threshold(self.value, threshold)
             self.counter = self.counter + 1
             return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
         else:
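The change moves the threshold into the restricted domain before it is combined with the learned scaling parameter, instead of dividing the raw parameter by a linear-domain threshold. Below is a minimal sketch (not the brevitas API) of why this matters, assuming a power-of-two restriction where the stored parameter lives in log2 space; `restrict_to_linear` and `restrict_init` are hypothetical stand-ins for what `restrict_clamp_scaling` and `restrict_init_module()` would do under that restriction.

```python
import torch


def restrict_to_linear(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for restrict_clamp_scaling under an assumed power-of-two restriction:
    # round in log2 space, then map back to a linear scale.
    return torch.exp2(torch.round(x))


def restrict_init(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for restrict_init_module(): move a linear-domain value into log2 space.
    return torch.log2(x)


value = torch.tensor(5.0)      # learned parameter, already in log2 space (scale = 2**5 = 32)
threshold = torch.tensor(8.0)  # threshold, given in the linear domain

# Old behaviour: the linear threshold is divided into the log2-space parameter,
# so it rescales the exponent rather than the scale itself.
old_scale = restrict_to_linear(value / threshold)  # exp2(round(0.625)) = 2.0

# New behaviour (this diff): restrict both operands first, then divide in the linear domain.
new_scale = restrict_to_linear(value) / restrict_to_linear(restrict_init(threshold))  # 32 / 8 = 4.0

print(old_scale.item(), new_scale.item())  # 2.0 4.0 (intended scale is 32 / 8 = 4.0)
```

With an unrestricted (float) scaling, both mappings reduce to identities, which is why the added comments describe the change as a no-op and retrocompatible for plain IntQuant configurations.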