Skip to content

Commit

Permalink
More fixes to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 19, 2024
1 parent 170ef90 commit 863a00c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
11 changes: 5 additions & 6 deletions tests/brevitas/core/test_runtime_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,44 @@
from brevitas.core.stats.stats_op import AbsMax
from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE

SCALING_MIN_VAL = 1e-6


def test_scaling_min_val_parameter():
inp = torch.zeros(1, 5, requires_grad=True)
scaling_min_val = torch.tensor(1e-6)
scaling_op = StatsFromParameterScaling(
scaling_stats_impl=AbsMax(),
scaling_stats_input_view_shape_impl=Identity(),
scaling_stats_input_concat_dim=None,
tracked_parameter_list=[inp],
scaling_shape=SCALAR_SHAPE,
restrict_scaling_impl=PowerOfTwoRestrictValue(),
scaling_min_val=scaling_min_val)
scaling_min_val=SCALING_MIN_VAL)
pre_scale = scaling_op(inp)
pre_scale.sum().backward()
assert not torch.isnan(inp.grad).any()


def test_scaling_min_val_runtime():
inp = torch.zeros(1, 5, requires_grad=True)
scaling_min_val = torch.tensor(1e-6)
scaling_op = RuntimeStatsScaling(
scaling_stats_impl=AbsMax(),
scaling_stats_input_view_shape_impl=Identity(),
scaling_shape=SCALAR_SHAPE,
restrict_scaling_impl=PowerOfTwoRestrictValue(),
scaling_min_val=scaling_min_val)
scaling_min_val=SCALING_MIN_VAL)
pre_scale = scaling_op(inp)
pre_scale.sum().backward()
assert not torch.isnan(inp.grad).any()


def test_scaling_min_val_dynamic_group():
inp = torch.zeros(1, 6, requires_grad=True)
scaling_min_val = torch.tensor(1e-6)
scaling_op = RuntimeDynamicGroupStatsScaling(
group_size=3,
group_dim=1,
input_view_impl=Identity(),
scaling_min_val=scaling_min_val,
scaling_min_val=SCALING_MIN_VAL,
restrict_scaling_impl=PowerOfTwoRestrictValue(),
scaling_stats_impl=AbsMax())
pre_scale = scaling_op(inp)
Expand Down
8 changes: 4 additions & 4 deletions tests/brevitas/core/test_standalone_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from brevitas.core.stats.stats_op import AbsMax
from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE

SCALING_MIN_VAL = 1e-6


def test_scaling_state_dict():
scaling_op = ParameterFromRuntimeStatsScaling(
Expand All @@ -22,11 +24,10 @@ def test_scaling_state_dict():

@torch.no_grad()
def test_scaling_min_val_runtime():
scaling_min_val = 1e-6
scaling_op = ParameterFromRuntimeStatsScaling(
collect_stats_steps=1,
scaling_stats_impl=AbsMax(),
scaling_min_val=scaling_min_val,
scaling_min_val=SCALING_MIN_VAL,
restrict_scaling_impl=PowerOfTwoRestrictValue())
inp = torch.zeros(1, 5)
pre_scale = scaling_op(inp)
Expand All @@ -38,10 +39,9 @@ def test_scaling_min_val_runtime():
@torch.no_grad()
def test_scaling_min_val_param():
inp = torch.zeros(1, 5)
scaling_min_val = 1e-6
scaling_op = ParameterFromStatsFromParameterScaling(
scaling_stats_impl=AbsMax(),
scaling_min_val=scaling_min_val,
scaling_min_val=SCALING_MIN_VAL,
restrict_scaling_impl=PowerOfTwoRestrictValue(),
scaling_stats_input_view_shape_impl=Identity(),
scaling_stats_input_concat_dim=None,
Expand Down

0 comments on commit 863a00c

Please sign in to comment.