Skip to content

Commit

Permalink
fix[next][dace]: Fix for DistributedBufferRelocator (#1814)
Browse files Browse the repository at this point in the history
This PR fixes a bug in `DistributedBufferRelocator` that was observed in
ICON4Py's `TestUpdateThetaAndExner` test.

In essence there was an `assert` that checked whether this temporary was
a sink node, but the code that finds all write-backs never excluded such
cases, i.e. the temporaries that were selected might not be sink nodes in
the state where they are defined.
The `assert` was not part of the original implementation and is not a
requirement of the transformation, instead it was introduced by
[PR#1799](#1799), that fixed some
issues in the analysis of read write dependencies.

There are two solutions for this: either removing the `assert` or pruning
these kinds of temporaries. After some consideration, it was realized
that handling such cases will not lead to an invalid SDFG, as long as the
other restrictions on the global data are respected. For that reason the
`assert` was removed.
However, we should think about doing something more intelligent in that
case.
  • Loading branch information
philip-paul-mueller authored Jan 22, 2025
1 parent 44578ec commit 9bbb952
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -551,11 +551,17 @@ def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]:
def_locations: list[AccessLocation] = []
for upstream_state in find_upstream_states(temp_storage_state):
if temp_storage_node.data in access_sets[upstream_state][1]:
def_locations.extend(
# NOTE: We do not impose any restriction on `temp_storage`. Thus
# It could be that we do read from it (we can never write to it)
# in this state or any other state later.
# TODO(phimuell): Should we require that `temp_storage` is a sink
# node? It might prevent or allow other optimizations.
new_locations = [
(data_node, upstream_state)
for data_node in upstream_state.data_nodes()
if data_node.data == temp_storage_node.data
)
]
def_locations.extend(new_locations)
if len(def_locations) != 0:
result_candidates.append((temp_storage, def_locations))

Expand Down Expand Up @@ -677,7 +683,6 @@ def _check_read_write_dependency_impl(

# Get the location and the state where the temporary is originally defined.
def_location_of_intermediate, state_to_inspect = target_location
assert state_to_inspect.out_degree(def_location_of_intermediate) == 0

# These are all access nodes that refers to the global data, that we want
# to move into the state `state_to_inspect`. We need them to do the
Expand All @@ -689,6 +694,8 @@ def _check_read_write_dependency_impl(
# empty Memlets. This is done because such Memlets are used to induce a
# schedule or order in the dataflow graph.
# As a byproduct, for the second test, we also collect all of these nodes.
# TODO(phimuell): Refine this such that it takes the location of the data
# into account.
for dnode in state_to_inspect.data_nodes():
if dnode.data != global_data_name:
continue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,63 @@ def test_distributed_buffer_global_memory_data_no_rance2():
res = gtx_transformations.gt_reduce_distributed_buffering(sdfg)
assert res[sdfg]["DistributedBufferRelocator"][state2] == {"t"}
assert state2.number_of_nodes() == 0


def _make_distributed_buffer_non_sink_temporary_sdfg() -> (
    tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]
):
    """Build an SDFG in which one transient is *not* a sink node.

    The first (computation) state derives ``t1`` from ``a`` and then ``t2``
    from ``t1``; hence ``t1`` has an outgoing edge and is not a sink node,
    while ``t2`` is. The second (write-back) state copies both transients
    into global memory.

    Returns:
        A triple of the SDFG, its computation state, and its write-back state.
    """
    sdfg = dace.SDFG(util.unique_name("distributed_buffer_non_sink_temporary_sdfg"))
    comp_state = sdfg.add_state(is_start_block=True)
    writeback_state = sdfg.add_state_after(comp_state)

    # Register all arrays as globals first, then flip the transients.
    for array_name in ("a", "b", "c", "t1", "t2"):
        sdfg.add_array(
            array_name,
            shape=(10,),
            dtype=dace.float64,
            transient=False,
        )
    for transient_name in ("t1", "t2"):
        sdfg.arrays[transient_name].transient = True

    # Shared access node: output of `comp1` and input of `comp2`, which is
    # what makes `t1` a non-sink node in the computation state.
    t1_node = comp_state.add_access("t1")

    comp_state.add_mapped_tasklet(
        "comp1",
        map_ranges={"__i": "0:10"},
        inputs={"__in1": dace.Memlet("a[__i]")},
        code="__out = __in1 + 10.0",
        outputs={"__out": dace.Memlet("t1[__i]")},
        output_nodes={t1_node},
        external_edges=True,
    )
    comp_state.add_mapped_tasklet(
        "comp2",
        map_ranges={"__i": "0:10"},
        inputs={"__in1": dace.Memlet("t1[__i]")},
        code="__out = __in1 / 2.0",
        outputs={"__out": dace.Memlet("t2[__i]")},
        input_nodes={t1_node},
        external_edges=True,
    )

    # NOTE(review): both transients are written back to `b` over the full
    #   range, and array `c` is declared but never used — confirm that
    #   `t2 -> b` (rather than `t2 -> c`) is intentional.
    writeback_state.add_nedge(
        writeback_state.add_access("t1"),
        writeback_state.add_access("b"),
        dace.Memlet("t1[0:10]"),
    )
    writeback_state.add_nedge(
        writeback_state.add_access("t2"),
        writeback_state.add_access("b"),
        dace.Memlet("t2[0:10]"),
    )

    sdfg.validate()
    return sdfg, comp_state, writeback_state


def test_distributed_buffer_non_sink_temporary():
    """Tests the transformation if one of the temporaries is not a sink node.

    Note that the SDFG has two temporaries, `t1` is not a sink node and `t2` is
    a sink node.
    """
    sdfg, state, wb_state = _make_distributed_buffer_non_sink_temporary_sdfg()
    # Two transient->global copies: two source and two destination access nodes.
    assert wb_state.number_of_nodes() == 4

    res = gtx_transformations.gt_reduce_distributed_buffering(sdfg)
    # Removed leftover debugging call `sdfg.view()` here: it launches an
    # interactive SDFG viewer (browser), which hangs/breaks headless CI runs.
    # Both temporaries must be relocated, leaving the write-back state empty.
    assert res[sdfg]["DistributedBufferRelocator"][wb_state] == {"t1", "t2"}
    assert wb_state.number_of_nodes() == 0

0 comments on commit 9bbb952

Please sign in to comment.