diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index bb95244aef..cc845505c9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -551,11 +551,17 @@ def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]: def_locations: list[AccessLocation] = [] for upstream_state in find_upstream_states(temp_storage_state): if temp_storage_node.data in access_sets[upstream_state][1]: - def_locations.extend( + # NOTE: We do not impose any restriction on `temp_storage`. Thus + # It could be that we do read from it (we can never write to it) + # in this state or any other state later. + # TODO(phimuell): Should we require that `temp_storage` is a sink + # node? It might prevent or allow other optimizations. + new_locations = [ (data_node, upstream_state) for data_node in upstream_state.data_nodes() if data_node.data == temp_storage_node.data - ) + ] + def_locations.extend(new_locations) if len(def_locations) != 0: result_candidates.append((temp_storage, def_locations)) @@ -677,7 +683,6 @@ def _check_read_write_dependency_impl( # Get the location and the state where the temporary is originally defined. def_location_of_intermediate, state_to_inspect = target_location - assert state_to_inspect.out_degree(def_location_of_intermediate) == 0 # These are all access nodes that refers to the global data, that we want # to move into the state `state_to_inspect`. We need them to do the @@ -689,6 +694,8 @@ def _check_read_write_dependency_impl( # empty Memlets. This is done because such Memlets are used to induce a # schedule or order in the dataflow graph. # As a byproduct, for the second test, we also collect all of these nodes. 
+ # TODO(phimuell): Refine this such that it takes the location of the data + # into account. for dnode in state_to_inspect.data_nodes(): if dnode.data != global_data_name: continue diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py index d61b8a2d42..709079dd0d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py @@ -287,3 +287,63 @@ def test_distributed_buffer_global_memory_data_no_rance2(): res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) assert res[sdfg]["DistributedBufferRelocator"][state2] == {"t"} assert state2.number_of_nodes() == 0 + + +def _make_distributed_buffer_non_sink_temporary_sdfg() -> ( + tuple[dace.SDFG, dace.SDFGState, dace.SDFGState] +): + sdfg = dace.SDFG(util.unique_name("distributed_buffer_non_sink_temporary_sdfg")) + state = sdfg.add_state(is_start_block=True) + wb_state = sdfg.add_state_after(state) + + names = ["a", "b", "c", "t1", "t2"] + for name in names: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["t1"].transient = True + sdfg.arrays["t2"].transient = True + t1 = state.add_access("t1") + + state.add_mapped_tasklet( + "comp1", + map_ranges={"__i": "0:10"}, + inputs={"__in1": dace.Memlet("a[__i]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("t1[__i]")}, + output_nodes={t1}, + external_edges=True, + ) + state.add_mapped_tasklet( + "comp2", + map_ranges={"__i": "0:10"}, + inputs={"__in1": dace.Memlet("t1[__i]")}, + code="__out = __in1 / 2.0", + outputs={"__out": dace.Memlet("t2[__i]")}, + 
input_nodes={t1},
+        external_edges=True,
+    )
+
+    wb_state.add_nedge(wb_state.add_access("t1"), wb_state.add_access("b"), dace.Memlet("t1[0:10]"))
+    wb_state.add_nedge(wb_state.add_access("t2"), wb_state.add_access("c"), dace.Memlet("t2[0:10]"))
+
+    sdfg.validate()
+    return sdfg, state, wb_state
+
+
+def test_distributed_buffer_non_sink_temporary():
+    """Tests the transformation if one of the temporaries is not a sink node.
+
+    Note that the SDFG has two temporaries, `t1` is not a sink node and `t2` is
+    a sink node.
+    """
+    sdfg, state, wb_state = _make_distributed_buffer_non_sink_temporary_sdfg()
+    assert wb_state.number_of_nodes() == 4
+
+    res = gtx_transformations.gt_reduce_distributed_buffering(sdfg)
+
+    assert res[sdfg]["DistributedBufferRelocator"][wb_state] == {"t1", "t2"}
+    assert wb_state.number_of_nodes() == 0