diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index b982dfd718..c99f45bb44 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -786,26 +786,22 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, # NOTE: In certain cases the corresponding subset might be None, in this case # we assume that the whole array is written, which is the default behaviour. ac_desc = n.desc(self.sdfg) - ac_size = ac_desc.total_size - in_subsets = dict() - for in_edge in in_edges: - # Ensure that if the destination subset is not given, our assumption, that the - # whole array is written to, is valid, by testing if the memlet transfers the - # whole array. - assert (in_edge.data.dst_subset is not None) or (in_edge.data.num_elements() == ac_size) - in_subsets[in_edge] = ( - sbs.Range.from_array(ac_desc) - if in_edge.data.dst_subset is None - else in_edge.data.dst_subset + in_subsets = { + in_edge: ( + sbs.Range.from_array(ac_desc) + if in_edge.data.dst_subset is None + else in_edge.data.dst_subset ) - out_subsets = dict() - for out_edge in out_edges: - assert (out_edge.data.src_subset is not None) or (out_edge.data.num_elements() == ac_size) - out_subsets[out_edge] = ( + for in_edge in in_edges + } + out_subsets = { + out_edge: ( sbs.Range.from_array(ac_desc) if out_edge.data.src_subset is None else out_edge.data.src_subset ) + for out_edge in out_edges + } # Update the read and write sets of the subgraph. if in_edges: diff --git a/tests/npbench/misc/stockham_fft_test.py b/tests/npbench/misc/stockham_fft_test.py index 8fc5e88203..5878cf621a 100644 --- a/tests/npbench/misc/stockham_fft_test.py +++ b/tests/npbench/misc/stockham_fft_test.py @@ -185,4 +185,4 @@ def test_fpga(): elif target == "gpu": run_stockham_fft(dace.dtypes.DeviceType.GPU) elif target == "fpga": - run_stockham_fft(dace.dtypes.DeviceType.FPGA) \ No newline at end of file + run_stockham_fft(dace.dtypes.DeviceType.FPGA) diff --git a/tests/sdfg/state_test.py b/tests/sdfg/state_test.py index 4bde3788e0..dc8ede776f 100644 --- a/tests/sdfg/state_test.py +++ b/tests/sdfg/state_test.py @@ -145,6 +145,52 @@ def test_read_and_write_set_selection(): assert found_match, f"Could not find the subset '{exp}' only got '{computed_sets}'" +def test_read_and_write_set_names(): + sdfg = dace.SDFG('test_read_and_write_set_names') + state = sdfg.add_state(is_start_block=True) + + # The arrays use different symbols for their sizes, but they are known to be the + # same. This happens for example if the SDFG is the result of some automatic + # translation from another IR, such as GTIR in GT4Py. + names = ["A", "B"] + for name in names: + sdfg.add_symbol(f"{name}_size_0", dace.int32) + sdfg.add_symbol(f"{name}_size_1", dace.int32) + sdfg.add_array( + name, + shape=(f"{name}_size_0", f"{name}_size_1"), + dtype=dace.float64, + transient=False, + ) + A, B = (state.add_access(name) for name in names) + + # Print copy `A` into `B`. + # Because, `dst_subset` is `None` we expect that everything is transferred. + state.add_nedge( + A, + B, + dace.Memlet("A[0:A_size_0, 0:A_size_1]"), + ) + expected_read_set = { + "A": [sbs.Range.from_string("0:A_size_0, 0:A_size_1")], + } + expected_write_set = { + "B": [sbs.Range.from_string("0:B_size_0, 0:B_size_1")], + } + read_set, write_set = state._read_and_write_sets() + + for expected_sets, computed_sets in [(expected_read_set, read_set), (expected_write_set, write_set)]: + assert expected_sets.keys() == computed_sets.keys(), f"Expected the set to contain '{expected_sets.keys()}' but got '{computed_sets.keys()}'." + for access_data in expected_sets.keys(): + for exp in expected_sets[access_data]: + found_match = False + for res in computed_sets[access_data]: + if res == exp: + found_match = True + break + assert found_match, f"Could not find the subset '{exp}' only got '{computed_sets}'" + + def test_add_mapped_tasklet(): sdfg = dace.SDFG("test_add_mapped_tasklet") state = sdfg.add_state(is_start_block=True) @@ -173,5 +219,6 @@ def test_add_mapped_tasklet(): test_read_and_write_set_filter() test_read_write_set() test_read_write_set_y_formation() + test_read_and_write_set_names() test_deepcopy_state() test_add_mapped_tasklet()