Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 20, 2023
1 parent e2c277d commit dabb8c5
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 13 deletions.
3 changes: 1 addition & 2 deletions tests/brevitas/graph/equalization_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from brevitas import torch_version
from brevitas.graph.equalize import _cross_layer_equalization
from brevitas.graph.equalize import _organize_region

SEED = 123456
ATOL = 1e-3
Expand All @@ -29,7 +28,7 @@
IN_SIZE_LINEAR = (1, 224, 3)


def equalize_test(model, regions, merge_bias, bias_shrinkage, scale_computation_type):
def equalize_test(regions, merge_bias, bias_shrinkage, scale_computation_type):
scale_factors_regions = []
for i in range(3):
for region in regions:
Expand Down
14 changes: 3 additions & 11 deletions tests/brevitas/graph/test_equalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_resnet18_equalization():
model_orig = copy.deepcopy(model)
regions = _extract_regions(model)
_ = equalize_test(
model, regions, merge_bias=True, bias_shrinkage='vaiq', scale_computation_type='maxabs')
regions, merge_bias=True, bias_shrinkage='vaiq', scale_computation_type='maxabs')
out = model(inp)

# Check that equalization is not introducing FP variations
Expand Down Expand Up @@ -75,11 +75,7 @@ def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool

regions = _extract_regions(model)
scale_factor_regions = equalize_test(
model,
regions,
merge_bias=merge_bias,
bias_shrinkage='vaiq',
scale_computation_type='maxabs')
regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs')
shape_scale_regions = [scale.shape for scale in scale_factor_regions]

out = model(inp)
Expand Down Expand Up @@ -132,11 +128,7 @@ def test_models(toy_model, merge_bias, request):
model = symbolic_trace(model)
regions = _extract_regions(model)
scale_factor_regions = equalize_test(
model,
regions,
merge_bias=merge_bias,
bias_shrinkage='vaiq',
scale_computation_type='maxabs')
regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs')
shape_scale_regions = [scale.shape for scale in scale_factor_regions]

with torch.no_grad():
Expand Down

0 comments on commit dabb8c5

Please sign in to comment.