diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index f47c24434..e67fb4f63 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -36,7 +36,7 @@ def test_resnet18_equalization(): # Check that equalization is not introducing FP variations assert torch.allclose(expected_out, out, atol=ATOL) - regions = sorted(regions, key=lambda region: sorted([r for r in region.srcs.keys()])) + regions = sorted(regions, key=lambda region: sorted([r for r in region.srcs_names])) resnet_18_regions = sorted(RESNET_18_REGIONS, key=lambda region: region[0][0]) equalized_layers = set() for r in resnet_18_regions: @@ -45,9 +45,9 @@ def test_resnet18_equalization(): # Check that we found all the expected regions for region, expected_region in zip(regions, resnet_18_regions): - srcs = list(region.srcs) + srcs = region.srcs_names sources_check = set(srcs) == set(expected_region[0]) - sinks = list(region.sinks) + sinks = region.sinks_names sinks_check = set(sinks) == set(expected_region[1]) assert sources_check assert sinks_check