From f3f75906855bde4ab9ad0a97abb16c138688d2b7 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 25 Nov 2024 14:19:16 +0000 Subject: [PATCH] Fix --- src/brevitas/core/scaling/standalone.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index 3afed108a..c76f3b769 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -395,7 +395,6 @@ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor: self.counter = new_counter return abs_binary_sign_grad(clamped_stats / threshold) elif self.counter == self.collect_stats_steps: - self.init_done = True self.restrict_inplace_preprocess(self.buffer) inplace_tensor_mul(self.value.detach(), self.buffer) threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold)) @@ -413,8 +412,7 @@ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor: def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Tensor: if threshold is None: threshold = torch.ones(1).type_as(stats_input) - if self.training: - self.init_done = False + if self.training and not self.init_done: # Threshold division handled inside the training_forward return self.training_forward(stats_input, threshold) else: