Fix (scaling/standalone): better switch from runtime stats to param
Giuseppe5 committed Nov 21, 2024
1 parent abf4a40 commit c6d58a5
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/brevitas/core/scaling/standalone.py
@@ -372,6 +372,7 @@ def __init__(
         self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
         self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
         self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module()
+        self.init_done: bool = brevitas.jit.Attribute(False, bool)

     @brevitas.jit.script_method
     def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
@@ -394,6 +395,7 @@ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
             self.counter = new_counter
             return abs_binary_sign_grad(clamped_stats / threshold)
         elif self.counter == self.collect_stats_steps:
+            self.init_done = True
             self.restrict_inplace_preprocess(self.buffer)
             inplace_tensor_mul(self.value.detach(), self.buffer)
             threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold))
@@ -415,18 +417,18 @@ def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
             # Threshold division handled inside the training_forward
             return self.training_forward(stats_input, threshold)
         else:
-            if self.counter <= self.collect_stats_steps:
-                out = self.buffer
+            if not self.init_done:
+                self.init_done = True
                 # No clamping is necessary since statistics are already clamped in training_forward
-                out = self.restrict_scaling_pre(out)
-            else:
-                out = self.value
+                self.restrict_inplace_preprocess(self.buffer)
+                inplace_tensor_mul(self.value.detach(), self.buffer)
+            out = self.value
             threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold))
             out = self.restrict_scaling(out)
             out = out / threshold
             # We can clamp after restrict val since the learned parameter is already in log-domain
             out = abs_binary_sign_grad(self.clamp_scaling(out))
-        return out
+            return out

     def state_dict(self, destination=None, prefix='', keep_vars=False):
         output_dict = super(ParameterFromRuntimeStatsScaling, self).state_dict(
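For context, here is a minimal, self-contained sketch of the pattern this commit fixes, written in plain PyTorch rather than against the Brevitas internals (the class name, the abs-max statistic, and the running-average update are illustrative assumptions, not Brevitas code). The scale is estimated from runtime statistics for the first collect_stats_steps training steps; after that, a learned parameter takes over. The init_done flag introduced by the commit guarantees the parameter is initialized from the collected statistics exactly once, even if the module is put in eval() before the collection phase has finished.

import torch
import torch.nn as nn


class RuntimeStatsToParamScale(nn.Module):
    # Illustrative sketch only, not the Brevitas implementation.

    def __init__(self, collect_stats_steps: int = 10):
        super().__init__()
        self.collect_stats_steps = collect_stats_steps
        self.counter = 0
        self.init_done = False
        self.register_buffer('buffer', torch.tensor(1.0))
        self.value = nn.Parameter(torch.tensor(1.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training and self.counter < self.collect_stats_steps:
            # Collection phase: running average of the observed abs-max statistic
            stat = x.abs().max().detach()
            self.buffer = (self.buffer * self.counter + stat) / (self.counter + 1)
            self.counter += 1
            return self.buffer
        if not self.init_done:
            # Switch point: fold the collected statistics into the learned
            # parameter exactly once. Before the fix, the eval path kept
            # reading the buffer and never initialized the parameter.
            self.init_done = True
            with torch.no_grad():
                self.value.copy_(self.buffer)
        return self.value


scale = RuntimeStatsToParamScale(collect_stats_steps=10)
scale.train()
_ = scale(torch.randn(8))      # only 1 of 10 collection steps has run
scale.eval()
print(scale(torch.randn(8)))   # value is initialized from the partial stats and used from now on

This mirrors what the change to forward() above does: in eval mode the buffer is folded into self.value the first time it is needed, instead of branching on the stats counter, so the learned parameter is the single source of truth after the switch.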
