From 863a00c8537f8c31470212a68183e6b3d0359d22 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 19 Nov 2024 14:20:28 +0000 Subject: [PATCH] More fixes to tests --- tests/brevitas/core/test_runtime_scaling.py | 11 +++++------ tests/brevitas/core/test_standalone_scaling.py | 8 ++++---- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/tests/brevitas/core/test_runtime_scaling.py b/tests/brevitas/core/test_runtime_scaling.py index eac52c591..e40d5cee3 100644 --- a/tests/brevitas/core/test_runtime_scaling.py +++ b/tests/brevitas/core/test_runtime_scaling.py @@ -8,10 +8,11 @@ from brevitas.core.stats.stats_op import AbsMax from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE +SCALING_MIN_VAL = 1e-6 + def test_scaling_min_val_parameter(): inp = torch.zeros(1, 5, requires_grad=True) - scaling_min_val = torch.tensor(1e-6) scaling_op = StatsFromParameterScaling( scaling_stats_impl=AbsMax(), scaling_stats_input_view_shape_impl=Identity(), @@ -19,7 +20,7 @@ def test_scaling_min_val_parameter(): tracked_parameter_list=[inp], scaling_shape=SCALAR_SHAPE, restrict_scaling_impl=PowerOfTwoRestrictValue(), - scaling_min_val=scaling_min_val) + scaling_min_val=SCALING_MIN_VAL) pre_scale = scaling_op(inp) pre_scale.sum().backward() assert not torch.isnan(inp.grad).any() @@ -27,13 +28,12 @@ def test_scaling_min_val_parameter(): def test_scaling_min_val_runtime(): inp = torch.zeros(1, 5, requires_grad=True) - scaling_min_val = torch.tensor(1e-6) scaling_op = RuntimeStatsScaling( scaling_stats_impl=AbsMax(), scaling_stats_input_view_shape_impl=Identity(), scaling_shape=SCALAR_SHAPE, restrict_scaling_impl=PowerOfTwoRestrictValue(), - scaling_min_val=scaling_min_val) + scaling_min_val=SCALING_MIN_VAL) pre_scale = scaling_op(inp) pre_scale.sum().backward() assert not torch.isnan(inp.grad).any() @@ -41,12 +41,11 @@ def test_scaling_min_val_runtime(): def test_scaling_min_val_dynamic_group(): inp = torch.zeros(1, 6, requires_grad=True) - scaling_min_val = torch.tensor(1e-6) scaling_op = RuntimeDynamicGroupStatsScaling( group_size=3, group_dim=1, input_view_impl=Identity(), - scaling_min_val=scaling_min_val, + scaling_min_val=SCALING_MIN_VAL, restrict_scaling_impl=PowerOfTwoRestrictValue(), scaling_stats_impl=AbsMax()) pre_scale = scaling_op(inp) diff --git a/tests/brevitas/core/test_standalone_scaling.py b/tests/brevitas/core/test_standalone_scaling.py index 5c1b0257b..cb2010d5d 100644 --- a/tests/brevitas/core/test_standalone_scaling.py +++ b/tests/brevitas/core/test_standalone_scaling.py @@ -9,6 +9,8 @@ from brevitas.core.stats.stats_op import AbsMax from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE +SCALING_MIN_VAL = 1e-6 + def test_scaling_state_dict(): scaling_op = ParameterFromRuntimeStatsScaling( @@ -22,11 +24,10 @@ def test_scaling_state_dict(): @torch.no_grad() def test_scaling_min_val_runtime(): - scaling_min_val = 1e-6 scaling_op = ParameterFromRuntimeStatsScaling( collect_stats_steps=1, scaling_stats_impl=AbsMax(), - scaling_min_val=scaling_min_val, + scaling_min_val=SCALING_MIN_VAL, restrict_scaling_impl=PowerOfTwoRestrictValue()) inp = torch.zeros(1, 5) pre_scale = scaling_op(inp) @@ -38,10 +39,9 @@ def test_scaling_min_val_runtime(): @torch.no_grad() def test_scaling_min_val_param(): inp = torch.zeros(1, 5) - scaling_min_val = 1e-6 scaling_op = ParameterFromStatsFromParameterScaling( scaling_stats_impl=AbsMax(), - scaling_min_val=scaling_min_val, + scaling_min_val=SCALING_MIN_VAL, restrict_scaling_impl=PowerOfTwoRestrictValue(), scaling_stats_input_view_shape_impl=Identity(), scaling_stats_input_concat_dim=None,