Skip to content

Commit

Permalink
fix[next][dace]: Bugfix in deref (dynamic memory allocation) (#1430)
Browse files: browse the repository at this point in the history.
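The baseline contained a bug in the lowering of `deref` in the context of neighbor reductions: the result data container was sized from the field array shape, which led to dynamic memory allocation. The container should instead be statically allocated, with a size equal to the `max_neighbors` attribute of the offset provider.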
  • Loading branch information
edopao authored Jan 30, 2024
1 parent f0986bb commit eb43002
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def _visit_scan_stencil_closure(
assert isinstance(node.output, SymRef)
neighbor_tables = filter_neighbor_tables(self.offset_provider)
input_names = [str(inp.id) for inp in node.inputs]
connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables]
connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]

# find the scan dimension, same as output dimension, and exclude it from the map domain
map_ranges = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,6 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
# already a list of ValueExpr
return iterator

args: list[ValueExpr]
sorted_dims = sorted(iterator.dimensions)
if all([dim in iterator.indices for dim in iterator.dimensions]):
# The deref iterator has index values on all dimensions: the result will be a scalar
Expand All @@ -781,16 +780,16 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
)

else:
# Not all dimensions are included in the deref index list:
# this means the ND-field will be sliced along one or more dimensions and the result will be an array
field_array = self.context.body.arrays[iterator.field.data]
result_shape = tuple(
dim_size
for dim, dim_size in zip(sorted_dims, field_array.shape)
if dim not in iterator.indices
)
dims_not_indexed = [dim for dim in iterator.dimensions if dim not in iterator.indices]
assert len(dims_not_indexed) == 1
offset = dims_not_indexed[0]
offset_provider = self.offset_provider[offset]
neighbor_dim = offset_provider.neighbor_axis.value

result_name = unique_var_name()
self.context.body.add_array(result_name, result_shape, iterator.dtype, transient=True)
self.context.body.add_array(
result_name, (offset_provider.max_neighbors,), iterator.dtype, transient=True
)
result_array = self.context.body.arrays[result_name]
result_node = self.context.state.add_access(result_name, debuginfo=di)

Expand All @@ -800,19 +799,17 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
deref_nodes = [iterator.field] + [
iterator.indices[dim] for dim in sorted_dims if dim in iterator.indices
]
deref_memlets = [dace.Memlet.from_array(iterator.field.data, field_array)] + [
dace.Memlet(data=node.data, subset="0") for node in deref_nodes[1:]
]
deref_memlets = [
dace.Memlet.from_array(iterator.field.data, iterator.field.desc(self.context.body))
] + [dace.Memlet(data=node.data, subset="0") for node in deref_nodes[1:]]

# we create a mapped tasklet for array slicing
index_name = unique_name(f"_i_{neighbor_dim}")
map_ranges = {
f"_i_{dim}": f"0:{size}"
for dim, size in zip(sorted_dims, field_array.shape)
if dim not in iterator.indices
index_name: f"0:{offset_provider.max_neighbors}",
}
src_subset = ",".join([f"_i_{dim}" for dim in sorted_dims])
dst_subset = ",".join(
[f"_i_{dim}" for dim in sorted_dims if dim not in iterator.indices]
src_subset = ",".join(
[f"_i_{dim}" if dim in iterator.indices else index_name for dim in sorted_dims]
)
self.context.state.add_mapped_tasklet(
"deref",
Expand All @@ -821,7 +818,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
outputs={
"_out": dace.Memlet.from_array(result_name, result_array),
},
code=f"_out[{dst_subset}] = _inp[{src_subset}]",
code=f"_out[{index_name}] = _inp[{src_subset}]",
external_edges=True,
input_nodes={node.data: node for node in deref_nodes},
output_nodes={
Expand Down Expand Up @@ -952,10 +949,10 @@ def _visit_reduce(self, node: itir.FunCall):
# set reduction state
self.context.reduce_identity = SymbolExpr(reduce_identity, reduce_dtype)

args = self.visit(node.args)
args = self.visit(node.args[0])

assert len(args) == 1 and len(args[0]) == 1
reduce_input_node = args[0][0].value
assert len(args) == 1
reduce_input_node = args[0].value

else:
assert isinstance(node.fun, itir.FunCall)
Expand Down

0 comments on commit eb43002

Please sign in to comment.