diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index ee704a715..3afed108a 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -414,6 +414,7 @@ def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Te if threshold is None: threshold = torch.ones(1).type_as(stats_input) if self.training: + self.init_done = False # Threshold division handled inside the training_forward return self.training_forward(stats_input, threshold) else: