From 63cc19dcd49d4cc5c26beb883530e9b53ff25d9d Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 21 Nov 2024 14:50:08 +0000 Subject: [PATCH] fix --- src/brevitas/core/scaling/standalone.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index ee704a715..3afed108a 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -414,6 +414,7 @@ def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Te if threshold is None: threshold = torch.ones(1).type_as(stats_input) if self.training: + self.init_done = False # Threshold division handled inside the training_forward return self.training_forward(stats_input, threshold) else: