From 2b6e81d95e4f67786847e2bd315884de0fcd385c Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 11 Dec 2023 17:14:18 +0000 Subject: [PATCH] Fix test --- tests/brevitas/graph/test_equalization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index f47c24434..e67fb4f63 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -36,7 +36,7 @@ def test_resnet18_equalization(): # Check that equalization is not introducing FP variations assert torch.allclose(expected_out, out, atol=ATOL) - regions = sorted(regions, key=lambda region: sorted([r for r in region.srcs.keys()])) + regions = sorted(regions, key=lambda region: sorted([r for r in region.srcs_names])) resnet_18_regions = sorted(RESNET_18_REGIONS, key=lambda region: region[0][0]) equalized_layers = set() for r in resnet_18_regions: @@ -45,9 +45,9 @@ def test_resnet18_equalization(): # Check that we found all the expected regions for region, expected_region in zip(regions, resnet_18_regions): - srcs = list(region.srcs) + srcs = region.srcs_names sources_check = set(srcs) == set(expected_region[0]) - sinks = list(region.sinks) + sinks = region.sinks_names sinks_check = set(sinks) == set(expected_region[1]) assert sources_check assert sinks_check