diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index a75e15b45..ee704a715 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -372,6 +372,7 @@ def __init__( self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module() self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module() self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module() + self.init_done: bool = brevitas.jit.Attribute(False, bool) @brevitas.jit.script_method def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor: @@ -394,6 +395,7 @@ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor: self.counter = new_counter return abs_binary_sign_grad(clamped_stats / threshold) elif self.counter == self.collect_stats_steps: + self.init_done = True self.restrict_inplace_preprocess(self.buffer) inplace_tensor_mul(self.value.detach(), self.buffer) threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold)) @@ -415,18 +417,18 @@ def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Te # Threshold division handled inside the training_forward return self.training_forward(stats_input, threshold) else: - if self.counter <= self.collect_stats_steps: - out = self.buffer + if not self.init_done: + self.init_done = True # No clamping is necessary since statistics are already clamped in training_forward - out = self.restrict_scaling_pre(out) - else: - out = self.value + self.restrict_inplace_preprocess(self.buffer) + inplace_tensor_mul(self.value.detach(), self.buffer) + out = self.value threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold)) out = self.restrict_scaling(out) out = out / threshold # We can clamp after restrict val since the learned parameter is already in log-domain out = abs_binary_sign_grad(self.clamp_scaling(out)) - return out + return out def state_dict(self, destination=None, prefix='', keep_vars=False): output_dict = super(ParameterFromRuntimeStatsScaling, self).state_dict(