diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 63e9fb03dc..525a5c694e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -461,7 +461,7 @@ def _visit_scan_stencil_closure( assert isinstance(node.output, SymRef) neighbor_tables = filter_neighbor_tables(self.offset_provider) input_names = [str(inp.id) for inp in node.inputs] - connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] + connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()] # find the scan dimension, same as output dimension, and exclude it from the map domain map_ranges = {} diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index ab03d29389..ba969608a7 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -763,7 +763,6 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: # already a list of ValueExpr return iterator - args: list[ValueExpr] sorted_dims = sorted(iterator.dimensions) if all([dim in iterator.indices for dim in iterator.dimensions]): # The deref iterator has index values on all dimensions: the result will be a scalar @@ -781,16 +780,16 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: ) else: - # Not all dimensions are included in the deref index list: - # this means the ND-field will be sliced along one or more dimensions and the result will be an array - field_array = self.context.body.arrays[iterator.field.data] - result_shape = tuple( - dim_size - for dim, dim_size in zip(sorted_dims, field_array.shape) - if dim not in iterator.indices - ) + dims_not_indexed = [dim for dim in iterator.dimensions if dim not in iterator.indices] + assert len(dims_not_indexed) == 1 + offset = dims_not_indexed[0] + offset_provider = self.offset_provider[offset] + neighbor_dim = offset_provider.neighbor_axis.value + result_name = unique_var_name() - self.context.body.add_array(result_name, result_shape, iterator.dtype, transient=True) + self.context.body.add_array( + result_name, (offset_provider.max_neighbors,), iterator.dtype, transient=True + ) result_array = self.context.body.arrays[result_name] result_node = self.context.state.add_access(result_name, debuginfo=di) @@ -800,19 +799,17 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: deref_nodes = [iterator.field] + [ iterator.indices[dim] for dim in sorted_dims if dim in iterator.indices ] - deref_memlets = [dace.Memlet.from_array(iterator.field.data, field_array)] + [ - dace.Memlet(data=node.data, subset="0") for node in deref_nodes[1:] - ] + deref_memlets = [ + dace.Memlet.from_array(iterator.field.data, iterator.field.desc(self.context.body)) + ] + [dace.Memlet(data=node.data, subset="0") for node in deref_nodes[1:]] # we create a mapped tasklet for array slicing + index_name = unique_name(f"_i_{neighbor_dim}") map_ranges = { - f"_i_{dim}": f"0:{size}" - for dim, size in zip(sorted_dims, field_array.shape) - if dim not in iterator.indices + index_name: f"0:{offset_provider.max_neighbors}", } - src_subset = ",".join([f"_i_{dim}" for dim in sorted_dims]) - dst_subset = ",".join( - [f"_i_{dim}" for dim in sorted_dims if dim not in iterator.indices] + src_subset = ",".join( + [f"_i_{dim}" if dim in iterator.indices else index_name for dim in sorted_dims] ) self.context.state.add_mapped_tasklet( "deref", @@ -821,7 +818,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: outputs={ "_out": dace.Memlet.from_array(result_name, result_array), }, - code=f"_out[{dst_subset}] = _inp[{src_subset}]", + code=f"_out[{index_name}] = _inp[{src_subset}]", external_edges=True, input_nodes={node.data: node for node in deref_nodes}, output_nodes={ @@ -952,10 +949,10 @@ def _visit_reduce(self, node: itir.FunCall): # set reduction state self.context.reduce_identity = SymbolExpr(reduce_identity, reduce_dtype) - args = self.visit(node.args) + args = self.visit(node.args[0]) - assert len(args) == 1 and len(args[0]) == 1 - reduce_input_node = args[0][0].value + assert len(args) == 1 + reduce_input_node = args[0].value else: assert isinstance(node.fun, itir.FunCall)