Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 25, 2024
1 parent 63cc19d commit f3f7590
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions src/brevitas/core/scaling/standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,6 @@ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
self.counter = new_counter
return abs_binary_sign_grad(clamped_stats / threshold)
elif self.counter == self.collect_stats_steps:
self.init_done = True
self.restrict_inplace_preprocess(self.buffer)
inplace_tensor_mul(self.value.detach(), self.buffer)
threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold))
Expand All @@ -413,8 +412,7 @@ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
if threshold is None:
threshold = torch.ones(1).type_as(stats_input)
if self.training:
self.init_done = False
if self.training and not self.init_done:
# Threshold division handled inside the training_forward
return self.training_forward(stats_input, threshold)
else:
Expand Down

0 comments on commit f3f7590

Please sign in to comment.