Skip to content

Commit

Permalink
threshold fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Sep 25, 2024
1 parent 9b8812a commit a7d5e67
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/brevitas/core/scaling/pre_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def forward(self, weights: Tensor) -> Tensor:
weights = self.stats_input_view_shape_impl(weights)
d_w = self.stats(weights) # denominator for weight normalization
g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) # g
s = self.scaling_impl(weights) # s
s = self.scaling_impl(weights, torch.tensor(1.).type_as(weights)) # s
value = (s * d_w) / g
return value

Expand Down Expand Up @@ -184,7 +184,7 @@ def calc_max_l1_norm(self, input_bit_width: Tensor, input_is_signed: bool) -> Te
def inner_forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool):
weights = self.stats_input_view_shape_impl(weights)
d_w = self.stats(weights) # denominator for weight normalization
s = self.scaling_impl(weights) # s
s = self.scaling_impl(weights, torch.tensor(1.).type_as(weights)) # s
g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) # g
T = self.calc_max_l1_norm(input_bit_width, input_is_signed) # T / s
g = torch.clamp_max(g / s, T)
Expand Down

0 comments on commit a7d5e67

Please sign in to comment.