From b93a4c962b642d836baa740da64bb4283f3a9f9b Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 22 Dec 2023 03:38:27 -0800 Subject: [PATCH] Complete coverage for reference-to-view pass (#1488) Adds a scoped test that completes coverage for the reference-to-view pass, leading to fixes of issues in the uncovered code. --- .../passes/reference_reduction.py | 7 ++- tests/sdfg/reference_test.py | 59 +++++++++++++++++++ 2 files changed, 63 insertions(+), 3 deletions(-) diff --git a/dace/transformation/passes/reference_reduction.py b/dace/transformation/passes/reference_reduction.py index 99bd2cea24..2af76852ba 100644 --- a/dace/transformation/passes/reference_reduction.py +++ b/dace/transformation/passes/reference_reduction.py @@ -121,7 +121,7 @@ def find_candidates( # Check if any of the symbols is a scope symbol entry = state.entry_node(node) while entry is not None: - if fsyms & entry.new_symbols(sdfg, state, {}): + if fsyms & entry.new_symbols(sdfg, state, {}).keys(): result.remove(cand) break entry = state.entry_node(entry) @@ -183,11 +183,12 @@ def remove_refsets( # Modify the state graph as necessary for e in edges_to_remove: - state.remove_edge_and_connectors(e) + state.remove_memlet_path(e) for n in nodes_to_remove: state.remove_node(n) for e in edges_to_add: - state.add_edge(*e) + if len(state.edges_between(e[0], e[2])) == 0: + state.add_edge(*e) for n in affected_nodes: # Orphaned nodes if n in nodes_to_remove: continue diff --git a/tests/sdfg/reference_test.py b/tests/sdfg/reference_test.py index 066bd80a7f..6c4d1eda1f 100644 --- a/tests/sdfg/reference_test.py +++ b/tests/sdfg/reference_test.py @@ -581,6 +581,61 @@ def test_reference_loop_nonfree_internal_use(): assert np.allclose(ref, A) +@pytest.mark.parametrize(('array_outside_scope', 'depends_on_iterate'), ((False, True), (False, True))) +def test_ref2view_refset_in_scope(array_outside_scope, depends_on_iterate): + sdfg = dace.SDFG('reftest') + sdfg.add_array('A', [20], dace.float64) + sdfg.add_array('B', [20], dace.float64) + sdfg.add_reference('ref', [1], dace.float64) + + memlet_string = 'A[i]' if depends_on_iterate else 'A[3]' + + state = sdfg.add_state() + me, mx = state.add_map('somemap', dict(i='0:20')) + arr = state.add_access('A') + ref = state.add_access('ref') + write = state.add_write('B') + + if array_outside_scope: + state.add_edge_pair(me, ref, arr, dace.Memlet(memlet_string), internal_connector='set') + else: + state.add_nedge(me, arr, dace.Memlet()) + state.add_edge(arr, None, ref, 'set', dace.Memlet(memlet_string)) + + t = state.add_tasklet('addone', {'inp'}, {'out'}, 'out = inp + 1') + state.add_edge(ref, None, t, 'inp', dace.Memlet('ref')) + state.add_edge_pair(mx, t, write, dace.Memlet('B[i]'), internal_connector='out') + + # Test sources + sources = FindReferenceSources().apply_pass(sdfg, {}) + assert len(sources) == 1 # There is only one SDFG + sources = sources[0] + assert len(sources) == 1 + assert sources['ref'] == {dace.Memlet(memlet_string)} + + # Test correctness before pass + A = np.random.rand(20) + B = np.random.rand(20) + ref = (A + 1) if depends_on_iterate else (A[3] + 1) + sdfg(A=A, B=B) + assert np.allclose(B, ref) + + # Test reference-to-view - should fail to apply + result = Pipeline([ReferenceToView()]).apply_pass(sdfg, {}) + if depends_on_iterate: + assert 'ReferenceToView' not in result or not result['ReferenceToView'] + else: + assert result['ReferenceToView'] == {'ref'} + + # Test correctness after pass + if not depends_on_iterate: + A = np.random.rand(20) + B = np.random.rand(20) + ref = (A + 1) if depends_on_iterate else (A[3] + 1) + sdfg(A=A, B=B) + assert np.allclose(B, ref) + + if __name__ == '__main__': test_unset_reference() test_reference_branch() @@ -603,3 +658,7 @@ def test_reference_loop_nonfree_internal_use(): test_reference_loop_internal_use(False) test_reference_loop_internal_use(True) test_reference_loop_nonfree_internal_use() + test_ref2view_refset_in_scope(False, False) + test_ref2view_refset_in_scope(False, True) + test_ref2view_refset_in_scope(True, False) + test_ref2view_refset_in_scope(True, True)