Skip to content

Commit

Permalink
extend support for sparse fields (exclude neighbors)
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Jan 24, 2025
1 parent 16af227 commit ccd70db
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 6 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ markers = [
'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields',
'uses_cartesian_shift: tests that use a Cartesian connectivity',
'uses_unstructured_shift: tests that use an unstructured connectivity',
'uses_unstructured_shift_with_sparse_fields: tests that use a connectivity with sparse fields',
'uses_max_over: tests that use the max_over builtin',
'uses_mesh_with_skip_values: tests that use a mesh with skip values',
'checks_specific_error: tests that rely on the backend to produce a specific error message'
Expand Down
106 changes: 100 additions & 6 deletions src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class IteratorExpr:

field: dace.nodes.AccessNode
gt_dtype: ts.ListType | ts.ScalarType
field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]]
field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymbolicType]]
indices: dict[gtx_common.Dimension, DataExpr]

def get_field_type(self) -> ts.FieldType:
Expand Down Expand Up @@ -508,9 +508,6 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr:
memlet; otherwise dereferencing is a runtime operation represented in
the SDFG as a tasklet node.
"""
# format used for field index tasklet connector
IndexConnectorFmt: Final = "__index_{dim}"

if isinstance(node.type, ts.TupleType):
raise NotImplementedError("Tuple deref not supported.")

Expand All @@ -537,8 +534,21 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr:
# we use a tasklet to dereference an iterator when one or more indices are the result of some computation,
# either indirection through connectivity table or dynamic cartesian offset.
assert all(dim in arg_expr.indices for dim, _ in arg_expr.field_domain)
assert len(field_desc.shape) == len(arg_expr.field_domain)
if isinstance(arg_expr.gt_dtype, ts.ScalarType):
assert len(field_desc.shape) == len(arg_expr.field_domain)
return self._deref_scalar(arg_expr, field_desc)
else:
# expect one extra dimension in dace array for the local list
assert len(field_desc.shape) == len(arg_expr.field_domain) + 1
return self._deref_list(arg_expr, field_desc)

def _deref_scalar(self, arg_expr: IteratorExpr, field_desc: dace.data.Array) -> ValueExpr:
field_offset = [offset for (_, offset) in arg_expr.field_domain]
field_indices = [(dim, arg_expr.indices[dim]) for dim, _ in arg_expr.field_domain]

# format used for field index tasklet connector
IndexConnectorFmt: Final = "__index_{dim}"

index_connectors = [
IndexConnectorFmt.format(dim=dim.value)
for dim, index in field_indices
Expand All @@ -565,7 +575,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr:
dace_subsets.Range.from_array(field_desc),
deref_node,
"field",
src_offset=[offset for (_, offset) in arg_expr.field_domain],
src_offset=field_offset,
)

for dim, index_expr in field_indices:
Expand All @@ -592,6 +602,87 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr:

return self._construct_tasklet_result(field_desc.dtype, deref_node, "val")

def _deref_list(self, arg_expr: IteratorExpr, field_desc: dace.data.Array) -> ValueExpr:
    """Dereference an iterator over a field carrying a local list (sparse field).

    Unlike the scalar case, the dace array backing the field has one extra
    trailing dimension for the local list (asserted by the caller), so the
    dereference copies all `max_neighbors` list elements into a 1D transient
    via a tasklet.

    Args:
        arg_expr: Iterator to dereference; its `gt_dtype` must be a
            `ts.ListType` with a non-null `offset_type`.
        field_desc: Dace array descriptor of the iterator's field.

    Returns:
        A `ValueExpr` wrapping the access node of a 1D temporary array of
        size `max_neighbors` holding the local list values, typed with the
        iterator's list dtype.
    """
    assert isinstance(arg_expr.gt_dtype, ts.ListType)
    assert arg_expr.gt_dtype.offset_type is not None
    offset_type = arg_expr.gt_dtype.offset_type
    offset_provider_type = self.subgraph_builder.get_offset_provider_type(offset_type.value)
    # the size of the local list is the (compile-time) max number of neighbors
    assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType)
    list_size = offset_provider_type.max_neighbors

    # extra trailing zero offset accounts for the local-list dimension of the array
    field_offset = [offset for (_, offset) in arg_expr.field_domain] + [0]
    field_indices = [(dim, arg_expr.indices[dim]) for dim, _ in arg_expr.field_domain]

    # format used for field index tasklet connector
    IndexConnectorFmt: Final = "__index_{dim}"

    # connectors are only needed for indices computed at runtime;
    # symbolic indices are inlined into the tasklet code string instead
    index_connectors = [
        IndexConnectorFmt.format(dim=dim.value)
        for dim, index in field_indices
        if not isinstance(index, SymbolExpr)
    ]
    # here `internals` refer to the names used as index in the tasklet code string:
    # an index can be either a connector name (for dynamic/indirect indices)
    # or a symbol value (for literal values and scalar arguments).
    index_internals = ",".join(
        str(index.value)
        if isinstance(index, SymbolExpr)
        else IndexConnectorFmt.format(dim=dim.value)
        for dim, index in field_indices
    )
    # the tasklet copies the whole local list (last array dimension) into `val`
    # NOTE(review): inner-loop indentation of the code string was reconstructed
    # from a whitespace-mangled source — confirm against the repository.
    deref_node = self._add_tasklet(
        "runtime_deref",
        {"field"} | set(index_connectors),
        {"val"},
        code=f"""
for i in range({list_size}):
    val[i] = field[{index_internals}, i]
""",
    )
    # add new termination point for the field parameter
    self._add_input_data_edge(
        arg_expr.field,
        dace_subsets.Range.from_array(field_desc),
        deref_node,
        "field",
        src_offset=field_offset,
    )

    for dim, index_expr in field_indices:
        # add termination points for the dynamic iterator indices
        deref_connector = IndexConnectorFmt.format(dim=dim.value)
        if isinstance(index_expr, MemletExpr):
            self._add_input_data_edge(
                index_expr.dc_node,
                index_expr.subset,
                deref_node,
                deref_connector,
            )

        elif isinstance(index_expr, ValueExpr):
            self._add_edge(
                index_expr.dc_node,
                None,
                deref_node,
                deref_connector,
                dace.Memlet(data=index_expr.dc_node.data, subset="0"),
            )
        else:
            # symbolic indices were inlined in the tasklet code; no edge needed
            assert isinstance(index_expr, SymbolExpr)

    # allocate a 1D transient to receive the dereferenced local list
    result, result_desc = self.subgraph_builder.add_temp_array(
        self.sdfg, (list_size,), field_desc.dtype
    )
    result_node = self.state.add_access(result)
    self._add_edge(
        deref_node,
        "val",
        result_node,
        None,
        dace.Memlet.from_array(result, result_desc),
    )
    return ValueExpr(result_node, arg_expr.gt_dtype)

def _visit_if_branch_arg(
self,
if_sdfg: dace.SDFG,
Expand Down Expand Up @@ -879,6 +970,9 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr:
assert isinstance(origin_index, SymbolExpr)
assert all(isinstance(index, SymbolExpr) for index in it.indices.values())

if isinstance(it.gt_dtype, ts.ListType):
raise NotImplementedError("Calling neighbors on a sparse field is not supported.")

field_desc = it.field.desc(self.sdfg)
connectivity = gtx_dace_utils.connectivity_identifier(offset)
# initially, the storage for the connectivity tables is created as transient;
Expand Down
2 changes: 2 additions & 0 deletions tests/next_tests/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
USES_ZERO_DIMENSIONAL_FIELDS = "uses_zero_dimensional_fields"
USES_CARTESIAN_SHIFT = "uses_cartesian_shift"
USES_UNSTRUCTURED_SHIFT = "uses_unstructured_shift"
USES_UNSTRUCTURED_SHIFT_WITH_SPARSE_FIELDS = "uses_unstructured_shift_with_sparse_fields"
USES_MAX_OVER = "uses_max_over"
USES_MESH_WITH_SKIP_VALUES = "uses_mesh_with_skip_values"
USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo"
Expand Down Expand Up @@ -149,6 +150,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
(USES_REDUCE_WITH_LAMBDA, XFAIL, UNSUPPORTED_MESSAGE),
(USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
(USES_TUPLE_ITERATOR, XFAIL, UNSUPPORTED_MESSAGE),
(USES_UNSTRUCTURED_SHIFT_WITH_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
]
)
EMBEDDED_SKIP_LIST = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ def sparse_shifted_stencil(inp):


@pytest.mark.uses_sparse_fields
@pytest.mark.uses_unstructured_shift_with_sparse_fields
def test_shift_sparse_input_field(program_processor):
program_processor, validate = program_processor
inp = gtx.as_field([Vertex, V2VDim], v2v_arr)
Expand Down Expand Up @@ -385,7 +386,9 @@ def shift_sparse_stencil2(inp):
return list_get(1, list_get(3, neighbors(V2E, inp)))


@pytest.mark.uses_composite_shifts
@pytest.mark.uses_sparse_fields
@pytest.mark.uses_unstructured_shift_with_sparse_fields
def test_shift_sparse_input_field2(program_processor):
program_processor, validate = program_processor
if program_processor in [
Expand Down

0 comments on commit ccd70db

Please sign in to comment.