Skip to content

Commit

Permalink
fix JIT
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Sep 30, 2024
1 parent 630f32e commit 8e168d2
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions src/brevitas/core/scaling/standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def __init__(
def forward(
self, ignored: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
if threshold is None:
threshold = torch.ones(1).type_as(stats_input)
threshold = torch.ones(1).type_as(ignored)
# Threshold division must happen after we update self.value, but before we apply restrict_preproces
# This is because we don't want to store a parameter dependant on a runtime value (threshold)
# And because restrict needs to happen after we divide by threshold
Expand Down Expand Up @@ -342,8 +342,7 @@ def __init__(
self.restrict_preprocess = Identity()

@brevitas.jit.script_method
def training_forward(
self, stats_input: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor:
# Threshold division must happen after we update self.value, but before we apply restrict_preproces
# This is because we don't want to store a parameter dependant on a runtime value (threshold)
# And because restrict needs to happen after we divide by threshold
Expand Down

0 comments on commit 8e168d2

Please sign in to comment.