Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 7, 2024
1 parent e399b7f commit b67fed4
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions src/brevitas/core/scaling/standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,11 @@ def __init__(
def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
if threshold is None:
threshold = torch.ones(1).type_as(placeholder)
value = self.value() / threshold
restricted_value = self.restrict_clamp_scaling(value)
# We first apply any restriction to scaling
# For IntQuant, this is no-op, retrocompatible.
threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold))
restricted_value = self.restrict_clamp_scaling(self.value())
restricted_value = restricted_value / threshold
return restricted_value


Expand Down Expand Up @@ -145,6 +148,7 @@ def __init__(
scaling_init = torch.tensor(scaling_init, dtype=dtype, device=device)
if restrict_scaling_impl is not None:
scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
if scaling_init.shape == SCALAR_SHAPE and scaling_shape is not None:
scaling_init = torch.full(scaling_shape, scaling_init, dtype=dtype, device=device)
self.value = Parameter(scaling_init)
Expand All @@ -154,8 +158,11 @@ def __init__(
def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
if threshold is None:
threshold = torch.ones(1).type_as(placeholder)
value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value / threshold))
return value
# We first apply any restriction to scaling
# For IntQuant, this is no-op, retrocompatible.
threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold))
value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
return value / threshold

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
Expand Down Expand Up @@ -363,7 +370,7 @@ def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tens
self.restrict_inplace_preprocess(self.buffer)
inplace_tensor_mul(self.value.detach(), self.buffer)
threshold = self.restrict_preprocess(threshold)
value = self.restrict_scaling_impl.combine_stats_threshold(value, threshold)
value = self.restrict_scaling_impl.combine_stats_threshold(self.value, threshold)
self.counter = self.counter + 1
return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
else:
Expand Down

0 comments on commit b67fed4

Please sign in to comment.