From 8fb0c09abd0834f7ecb6be965cb6ee8347d940b8 Mon Sep 17 00:00:00 2001
From: Edoardo Paone
Date: Fri, 17 Jan 2025 14:14:51 +0100
Subject: [PATCH] Add support for ListType as scan output

---
 .../gtir_builtin_translators.py | 94 +++++++++++++------
 1 file changed, 64 insertions(+), 30 deletions(-)

diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py
index 20e7dba1f1..ff75103cf7 100644
--- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py
+++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py
@@ -310,27 +310,32 @@ def _create_field_operator_impl(
         domain_subset = dace_subsets.Range.from_indices(domain_indices)
 
     if isinstance(output_edge.result.gt_dtype, ts.ScalarType):
-        assert output_edge.result.gt_dtype == output_type.dtype
-        field_dtype = output_edge.result.gt_dtype
-        field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset)
+        if output_edge.result.gt_dtype != output_type.dtype:
+            raise TypeError(
+                f"Type mismatch, expected {output_type.dtype} got {output_edge.result.gt_dtype}."
+            )
+        element_type = output_edge.result.gt_dtype
+        field_shape, field_offset = (domain_shape, domain_offset)
         assert isinstance(dataflow_output_desc, dace.data.Scalar)
         field_subset = domain_subset
     else:
         assert isinstance(output_type.dtype, ts.ListType)
         assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType)
-        assert output_edge.result.gt_dtype.element_type == output_type.dtype.element_type
-        field_dtype = output_edge.result.gt_dtype.element_type
+        assert output_edge.result.gt_dtype.offset_type is not None
+        element_type = output_edge.result.gt_dtype.element_type
+        if element_type != output_type.dtype.element_type:
+            raise TypeError(
+                f"Type mismatch, expected {output_type.dtype.element_type} got {element_type}."
+            )
         assert isinstance(dataflow_output_desc, dace.data.Array)
         assert len(dataflow_output_desc.shape) == 1
         # extend the array with the local dimensions added by the field operator (e.g. `neighbors`)
-        assert output_edge.result.gt_dtype.offset_type is not None
-        field_dims = [*domain_dims, output_edge.result.gt_dtype.offset_type]
         field_shape = [*domain_shape, dataflow_output_desc.shape[0]]
         field_offset = [*domain_offset, dataflow_output_desc.offset[0]]
         field_subset = domain_subset + dace_subsets.Range.from_array(dataflow_output_desc)
 
     # allocate local temporary storage
-    assert dataflow_output_desc.dtype == dace_utils.as_dace_type(field_dtype)
+    assert dataflow_output_desc.dtype == dace_utils.as_dace_type(element_type)
     field_name, _ = sdfg_builder.add_temp_array(sdfg, field_shape, dataflow_output_desc.dtype)
     field_node = state.add_access(field_name)
 
@@ -339,7 +344,7 @@
 
     return FieldopData(
         field_node,
-        ts.FieldType(field_dims, field_dtype),
+        ts.FieldType(domain_dims, output_edge.result.gt_dtype),
         offset=(field_offset if set(field_offset) != {0} else None),
     )
 
@@ -358,44 +363,70 @@ def _create_scan_field_operator_impl(
     Similar to `_create_field_operator_impl()` but for scan field operators.
""" dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) + assert isinstance(dataflow_output_desc, dace.data.Array) domain_dims, domain_offset, domain_shape = _get_field_layout(domain) domain_indices = _get_domain_indices(domain_dims, domain_offset) - domain_subset = dace_subsets.Range.from_indices(domain_indices) + + # the vertical dimension should not belong to the field operator domain + # but we need to write it to the output field + scan_dim_index = domain_dims.index(scan_dim) + + domain_subset = ( + dace_subsets.Range.from_indices(domain_indices[:scan_dim_index]) + + dace_subsets.Range.from_string(f"0:{dataflow_output_desc.shape[0]}") + + dace_subsets.Range.from_indices(domain_indices[scan_dim_index + 1 :]) + ) if isinstance(output_edge.result.gt_dtype, ts.ScalarType): - # the scan field operator produces a 1D vertical field - assert isinstance(dataflow_output_desc, dace.data.Array) + assert isinstance(output_type.dtype, ts.ScalarType) + if output_edge.result.gt_dtype != output_type.dtype: + raise TypeError( + f"Type mismatch, expected {output_type.dtype} got {output_edge.result.gt_dtype}." + ) + element_type = output_edge.result.gt_dtype + field_shape, field_offset = (domain_shape, domain_offset) + # the scan field operator computes a column of scalar values assert len(dataflow_output_desc.shape) == 1 - assert output_edge.result.gt_dtype == output_type.dtype - field_dtype = output_edge.result.gt_dtype - field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset) - # the vertical dimension should not belong to the field operator domain - # but we need to write it to the output field - scan_dim_index = domain_dims.index(scan_dim) - field_subset = ( - dace_subsets.Range(domain_subset[:scan_dim_index]) - + dace_subsets.Range.from_array(dataflow_output_desc) - + dace_subsets.Range(domain_subset[scan_dim_index + 1 :]) - ) + field_subset = domain_subset else: - raise NotImplementedError("List of values not supported in scan field operators.") + assert isinstance(output_type.dtype, ts.ListType) + assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) + assert output_edge.result.gt_dtype.offset_type is not None + element_type = output_edge.result.gt_dtype.element_type + if element_type != output_type.dtype.element_type: + raise TypeError( + f"Type mismatch, expected {output_type.dtype.element_type} got {element_type}." + ) + # the scan field operator computes a list of scalar values for each column level + assert len(dataflow_output_desc.shape) == 2 + # extend the array with the local dimensions added by the field operator (e.g. 
+        field_shape = [*domain_shape, dataflow_output_desc.shape[1]]
+        field_offset = [*domain_offset, dataflow_output_desc.offset[1]]
+        field_subset = domain_subset + dace_subsets.Range.from_string(
+            f"0:{dataflow_output_desc.shape[1]}"
+        )
 
     # allocate local temporary storage
-    assert dataflow_output_desc.dtype == dace_utils.as_dace_type(field_dtype)
+    assert dataflow_output_desc.dtype == dace_utils.as_dace_type(element_type)
     field_name, field_desc = sdfg_builder.add_temp_array(
         sdfg, field_shape, dataflow_output_desc.dtype
     )
+    # the inner and outer strides have to match
     scan_output_stride = field_desc.strides[scan_dim_index]
-    dataflow_output_desc.strides = (scan_output_stride,)
-    field_node = state.add_access(field_name)
+    # also consider the stride of the local dimension, in case the scan field operator computes a list
+    local_strides = field_desc.strides[len(domain_dims) :]
+    assert len(local_strides) == (1 if isinstance(output_edge.result.gt_dtype, ts.ListType) else 0)
+    new_inner_strides = [scan_output_stride, *local_strides]
+    dataflow_output_desc.set_shape(dataflow_output_desc.shape, new_inner_strides)
 
     # and here the edge writing the dataflow result data through the map exit node
+    field_node = state.add_access(field_name)
     output_edge.connect(map_exit, field_node, field_subset)
 
     return FieldopData(
         field_node,
-        ts.FieldType(field_dims, field_dtype),
+        ts.FieldType(domain_dims, output_edge.result.gt_dtype),
         offset=(field_offset if set(field_offset) != {0} else None),
     )
 
@@ -436,9 +467,12 @@ def _create_field_operator(
     assert scan_dim is None or scan_dim in domain_dims
 
     if scan_dim and len(domain_dims) == 1:
-        # We construct the scan field operator only on the horizontal domain.
-        # If the field operator computes only the scan dimension,
+        # We construct the scan field operator on the horizontal domain, while the
+        # vertical dimension (the column axis) is computed by the loop region.
+        # If the field operator computes only the column axis (a 1d scan field operator),
         # there is no horizontal domain, therefore the map scope is not needed.
+        # This case currently triggers a DaCe issue and produces wrong CUDA code,
+        # thus the corresponding test is disabled (see pytest marker `uses_scan_1d_field`).
         map_entry, map_exit = (None, None)
     else:
         # create map range corresponding to the field operator domain
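
Note (editorial, not part of the patch): the sketch below illustrates the layout logic this change introduces for list-valued scan outputs, using plain Python with hypothetical dimension names and sizes. The real implementation in `_create_scan_field_operator_impl` works on DaCe data descriptors and symbolic ranges; this is only a worked example of the shape, subset, and stride bookkeeping.

    # Hypothetical horizontal dimension "Cell" (size 10), scan dimension "K" (size 80),
    # and a dataflow result that yields a list of 4 values per column level.
    domain_dims = ["Cell", "K"]
    domain_shape = [10, 80]
    dataflow_output_shape = (80, 4)  # (column levels, local list size)

    scan_dim_index = domain_dims.index("K")

    # the output field gets one extra local dimension appended after the domain dimensions
    field_shape = [*domain_shape, dataflow_output_shape[1]]  # -> [10, 80, 4]

    # the write subset covers one horizontal point, the full column and the full local dimension,
    # mirroring `domain_subset + Range.from_string(f"0:{shape[1]}")` in the patch
    field_subset = ["i_Cell", f"0:{dataflow_output_shape[0]}", f"0:{dataflow_output_shape[1]}"]

    # with C-order strides on the [10, 80, 4] temporary, the inner (dataflow) array must reuse
    # the strides of the scan dimension and of the local dimension so that inner and outer
    # layouts match, as done via `dataflow_output_desc.set_shape(...)` in the patch
    field_strides = (80 * 4, 4, 1)
    inner_strides = (field_strides[scan_dim_index], *field_strides[len(domain_dims):])  # (4, 1)

    print(field_shape, field_subset, inner_strides)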