Fix (scaling)!: clamp to avoid inf/nan in forward/backward (#1097)
Breaking change: new models trained after this PR, especially with QAT or gradient-based PTQ, might converge to different results.
Giuseppe5 authored Nov 19, 2024
1 parent 9b31212 commit abf4a40
Showing 4 changed files with 108 additions and 3 deletions.
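
To see why the clamp matters, here is a minimal, standalone PyTorch sketch (not Brevitas internals) of the failure mode the commit title describes: a power-of-two restriction takes log2 of the collected statistic, so an all-zero input produces -inf in the forward pass and NaN gradients in the backward pass, while clamping the statistic to a small minimum first keeps both finite. The `min_val` of 1e-6 mirrors the `SCALING_MIN_VAL` used in the new tests; everything else is illustrative.

```python
import torch

# An all-zero input with a log2-based (power-of-two) restriction:
# log2(0) = -inf in the forward pass, NaN gradients in the backward pass.
x = torch.zeros(1, 5, requires_grad=True)
stat = x.abs().max()                # AbsMax-style statistic, 0 here
scale_bad = 2 ** torch.log2(stat)   # log2(0) = -inf -> scale 0
scale_bad.backward()
print(x.grad)                       # contains NaN

# Clamping the statistic to a small minimum before the log keeps the
# forward value and the gradient finite, which is what this commit adds.
x2 = torch.zeros(1, 5, requires_grad=True)
min_val = 1e-6                      # illustrative, mirrors SCALING_MIN_VAL in the tests
stat2 = torch.clamp(x2.abs().max(), min=min_val)
scale_ok = 2 ** torch.log2(stat2)
scale_ok.backward()
print(x2.grad)                      # finite (all zeros here, clamp is active), no NaN
```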
11 changes: 9 additions & 2 deletions src/brevitas/core/scaling/runtime.py
@@ -10,6 +10,7 @@
import brevitas
import brevitas.config as config
from brevitas.core.function_wrapper import Identity
from brevitas.core.restrict_val import _ClampValue
from brevitas.core.restrict_val import _RestrictClampValue
from brevitas.core.restrict_val import FloatRestrictValue
from brevitas.core.stats import _ParameterListStats
@@ -95,6 +96,7 @@ def __init__(
restrict_value_impl=restrict_threshold_impl)
self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module()
self.clamp_scaling = _ClampValue(scaling_min_val)

@brevitas.jit.script_method
def forward(
@@ -103,6 +105,8 @@ def forward(
threshold = torch.ones(1).type_as(stats)
threshold = self.restrict_threshold_pre(threshold)
threshold = self.restrict_clamp_threshold(threshold)
# Clamping avoids a potential log(0) with restrict_val
stats = self.clamp_scaling(stats)
stats = self.restrict_scaling_pre(stats)
stats = self.affine_rescaling(stats)
stats = self.restrict_clamp_scaling(stats)
@@ -218,6 +222,7 @@ def __init__(
)
self.restrict_threshold_pre = self.restrict_clamp_threshold.restrict_value_impl.restrict_init_module(
)
self.clamp_scaling = _ClampValue(scaling_min_val)

@brevitas.jit.script_method
def forward(
@@ -229,8 +234,10 @@ def forward(
stats_input_reshaped = self.input_view_impl(stats_input)
threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold))
out = self.scaling_stats_impl(stats_input_reshaped)
# Apply log scaling
# Clamping avoids a potential log(0) with restrict_val
out = self.clamp_scaling(out)
# Apply restrict_value preprocess
out = self.restrict_scaling_pre(out)
# Scaling min val
# Apply restrict_value and clamping
out = self.restrict_clamp_scaling(out) / threshold
return out
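
The runtime.py changes above share one ordering: clamp the statistic before the restrict-value preprocessing, then restrict, then clamp again in the linear domain and divide by the threshold. Below is a rough sketch of that ordering, assuming `_ClampValue` behaves like a simple clamp-to-minimum wrapper (its implementation is not shown in this diff) and using a power-of-two restriction as the concrete example; all names are illustrative, not Brevitas APIs.

```python
import torch
import torch.nn as nn


class ClampMin(nn.Module):
    """Stand-in for _ClampValue: clamps its input to a minimum value."""

    def __init__(self, min_val: float):
        super().__init__()
        self.min_val = min_val

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.clamp_min(self.min_val)


def scale_from_stats(stats: torch.Tensor, threshold: torch.Tensor, min_val: float = 1e-6):
    clamp_scaling = ClampMin(min_val)
    # 1. Clamp first, so the log-domain preprocessing below never sees 0.
    stats = clamp_scaling(stats)
    # 2. restrict_scaling_pre equivalent for a power-of-two restriction: log2.
    stats = torch.log2(stats)
    # 3. restrict_clamp_scaling equivalent: back to linear domain, then clamp.
    scale = torch.clamp_min(2 ** stats, min_val)
    return scale / threshold


print(scale_from_stats(torch.zeros(1), torch.ones(1)))  # finite, equals min_val
```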
9 changes: 8 additions & 1 deletion src/brevitas/core/scaling/standalone.py
@@ -177,6 +177,7 @@ def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Te
# We first apply any restriction to scaling
# For IntQuant, this is no-op, retrocompatible.
threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold))
# We can clamp after restrict val since the learned parameter is already in log-domain
value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
return value / threshold

@@ -234,6 +235,7 @@ def __init__(
device)
self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module()
self.restrict_inplace_scaling_pre = restrict_scaling_impl.restrict_init_inplace_module()
self.clamp_scaling = _ClampValue(scaling_min_val)

self.init_done: bool = brevitas.jit.Attribute(False, bool)
self.local_loss_mode: bool = brevitas.jit.Attribute(False, bool)
@@ -255,7 +257,10 @@ def forward(self, x: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
# workaround to avoid find_unused_parameters=True in DDP
stats = stats + 0. * self.value
if self.local_loss_mode:
# Scaling implementation before/after restrict_val is performed in stats_scaling_impl
return self.stats_scaling_impl(stats, threshold)
# Clamping avoids a potential log(0) with restrict_val
stats = self.clamp_scaling(stats)
stats = self.restrict_inplace_scaling_pre(stats)
threshold = self.stats_scaling_impl.restrict_clamp_threshold(
self.restrict_threshold_pre(threshold))
@@ -412,12 +417,14 @@ def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Te
else:
if self.counter <= self.collect_stats_steps:
out = self.buffer
# No clamping is necessary since statistics are already clamped in training_forward
out = self.restrict_scaling_pre(out)
else:
out = self.value
threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold))
out = self.clamp_scaling(self.restrict_scaling(out))
out = self.restrict_scaling(out)
out = out / threshold
# We can clamp after restrict val since the learned parameter is already in log-domain
out = abs_binary_sign_grad(self.clamp_scaling(out))
return out
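
For the standalone case the comments make the opposite choice explicit: once the scale is a learned parameter stored in the restricted (log) domain, the clamp is applied after the restriction, on the linear-domain value. An illustrative toy example of that ordering (not Brevitas code; the names are assumptions):

```python
import torch

# A learned scale kept in the log2 domain: map back to linear domain first,
# then clamp the result, so the scale stays away from 0 and the gradient
# through the clamp remains finite.
log2_scale = torch.nn.Parameter(torch.tensor(-40.0))  # 2**-40 is tiny
min_val = 1e-6
scale = torch.clamp_min(2 ** log2_scale, min_val)      # restrict, then clamp
threshold = torch.ones(1)
out = scale / threshold
out.sum().backward()
print(out, log2_scale.grad)                            # finite scale and gradient
```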

53 changes: 53 additions & 0 deletions tests/brevitas/core/test_runtime_scaling.py
@@ -0,0 +1,53 @@
import torch

from brevitas.core.function_wrapper.misc import Identity
from brevitas.core.restrict_val import PowerOfTwoRestrictValue
from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling
from brevitas.core.scaling.runtime import RuntimeStatsScaling
from brevitas.core.scaling.runtime import StatsFromParameterScaling
from brevitas.core.stats.stats_op import AbsMax
from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE

SCALING_MIN_VAL = 1e-6


def test_scaling_min_val_parameter():
inp = torch.zeros(1, 5, requires_grad=True)
scaling_op = StatsFromParameterScaling(
scaling_stats_impl=AbsMax(),
scaling_stats_input_view_shape_impl=Identity(),
scaling_stats_input_concat_dim=None,
tracked_parameter_list=[inp],
scaling_shape=SCALAR_SHAPE,
restrict_scaling_impl=PowerOfTwoRestrictValue(),
scaling_min_val=SCALING_MIN_VAL)
pre_scale = scaling_op(inp)
pre_scale.sum().backward()
assert not torch.isnan(inp.grad).any()


def test_scaling_min_val_runtime():
inp = torch.zeros(1, 5, requires_grad=True)
scaling_op = RuntimeStatsScaling(
scaling_stats_impl=AbsMax(),
scaling_stats_input_view_shape_impl=Identity(),
scaling_shape=SCALAR_SHAPE,
restrict_scaling_impl=PowerOfTwoRestrictValue(),
scaling_min_val=SCALING_MIN_VAL)
pre_scale = scaling_op(inp)
pre_scale.sum().backward()
assert not torch.isnan(inp.grad).any()


def test_scaling_min_val_dynamic_group():
inp = torch.zeros(1, 6, requires_grad=True)
scaling_op = RuntimeDynamicGroupStatsScaling(
group_size=3,
group_dim=1,
input_view_impl=Identity(),
scaling_min_val=SCALING_MIN_VAL,
restrict_scaling_impl=PowerOfTwoRestrictValue(),
scaling_stats_impl=AbsMax())
pre_scale = scaling_op(inp)
pre_scale.sum().backward()
assert not torch.isnan(inp.grad).any()
38 changes: 38 additions & 0 deletions tests/brevitas/core/test_standalone_scaling.py
@@ -1,7 +1,15 @@
import warnings

import torch

from brevitas.core.function_wrapper.misc import Identity
from brevitas.core.restrict_val import PowerOfTwoRestrictValue
from brevitas.core.scaling import ParameterFromRuntimeStatsScaling
from brevitas.core.scaling.standalone import ParameterFromStatsFromParameterScaling
from brevitas.core.stats.stats_op import AbsMax
from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE

SCALING_MIN_VAL = 1e-6


def test_scaling_state_dict():
@@ -12,3 +20,33 @@ def test_scaling_state_dict():
scaling_op.state_dict()
for w in wlist:
assert "Positional args are being deprecated" not in str(w.message)


@torch.no_grad()
def test_scaling_min_val_runtime():
scaling_op = ParameterFromRuntimeStatsScaling(
collect_stats_steps=1,
scaling_stats_impl=AbsMax(),
scaling_min_val=SCALING_MIN_VAL,
restrict_scaling_impl=PowerOfTwoRestrictValue())
inp = torch.zeros(1, 5)
pre_scale = scaling_op(inp)
value_scale_converted = scaling_op(inp)
scaling_op.eval()
assert not torch.isinf(scaling_op.value).any()


@torch.no_grad()
def test_scaling_min_val_param():
inp = torch.zeros(1, 5)
scaling_op = ParameterFromStatsFromParameterScaling(
scaling_stats_impl=AbsMax(),
scaling_min_val=SCALING_MIN_VAL,
restrict_scaling_impl=PowerOfTwoRestrictValue(),
scaling_stats_input_view_shape_impl=Identity(),
scaling_stats_input_concat_dim=None,
tracked_parameter_list=[inp],
scaling_shape=SCALAR_SHAPE)
pre_scale = scaling_op(inp)
value_scale_converted = scaling_op(inp)
assert not torch.isinf(scaling_op.value).any()
