diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index fde3a60d9..263543a82 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -11,7 +11,6 @@ from brevitas import torch_version from brevitas.graph.equalize import _cross_layer_equalization -from brevitas.graph.equalize import _organize_region SEED = 123456 ATOL = 1e-3 @@ -29,7 +28,7 @@ IN_SIZE_LINEAR = (1, 224, 3) -def equalize_test(model, regions, merge_bias, bias_shrinkage, scale_computation_type): +def equalize_test(regions, merge_bias, bias_shrinkage, scale_computation_type): scale_factors_regions = [] for i in range(3): for region in regions: diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index e67fb4f63..4f713211c 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -30,7 +30,7 @@ def test_resnet18_equalization(): model_orig = copy.deepcopy(model) regions = _extract_regions(model) _ = equalize_test( - model, regions, merge_bias=True, bias_shrinkage='vaiq', scale_computation_type='maxabs') + regions, merge_bias=True, bias_shrinkage='vaiq', scale_computation_type='maxabs') out = model(inp) # Check that equalization is not introducing FP variations @@ -75,11 +75,7 @@ def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool regions = _extract_regions(model) scale_factor_regions = equalize_test( - model, - regions, - merge_bias=merge_bias, - bias_shrinkage='vaiq', - scale_computation_type='maxabs') + regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs') shape_scale_regions = [scale.shape for scale in scale_factor_regions] out = model(inp) @@ -132,11 +128,7 @@ def test_models(toy_model, merge_bias, request): model = symbolic_trace(model) regions = _extract_regions(model) scale_factor_regions = equalize_test( - model, - regions, - merge_bias=merge_bias, - bias_shrinkage='vaiq', - scale_computation_type='maxabs') + regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs') shape_scale_regions = [scale.shape for scale in scale_factor_regions] with torch.no_grad():