Skip to content

Commit

Permalink
Add support for ListType as scan output
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Jan 17, 2025
1 parent ca246d6 commit 8fb0c09
Showing 1 changed file with 64 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -310,27 +310,32 @@ def _create_field_operator_impl(
domain_subset = dace_subsets.Range.from_indices(domain_indices)

if isinstance(output_edge.result.gt_dtype, ts.ScalarType):
assert output_edge.result.gt_dtype == output_type.dtype
field_dtype = output_edge.result.gt_dtype
field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset)
if output_edge.result.gt_dtype != output_type.dtype:
raise TypeError(
f"Type mismatch, expected {output_type.dtype} got {output_edge.result.gt_dtype}."
)
element_type = output_edge.result.gt_dtype
field_shape, field_offset = (domain_shape, domain_offset)
assert isinstance(dataflow_output_desc, dace.data.Scalar)
field_subset = domain_subset
else:
assert isinstance(output_type.dtype, ts.ListType)
assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType)
assert output_edge.result.gt_dtype.element_type == output_type.dtype.element_type
field_dtype = output_edge.result.gt_dtype.element_type
assert output_edge.result.gt_dtype.offset_type is not None
element_type = output_edge.result.gt_dtype.element_type
if element_type != output_type.dtype.element_type:
raise TypeError(
f"Type mismatch, expected {output_type.dtype.element_type} got {element_type}."
)
assert isinstance(dataflow_output_desc, dace.data.Array)
assert len(dataflow_output_desc.shape) == 1
# extend the array with the local dimensions added by the field operator (e.g. `neighbors`)
assert output_edge.result.gt_dtype.offset_type is not None
field_dims = [*domain_dims, output_edge.result.gt_dtype.offset_type]
field_shape = [*domain_shape, dataflow_output_desc.shape[0]]
field_offset = [*domain_offset, dataflow_output_desc.offset[0]]
field_subset = domain_subset + dace_subsets.Range.from_array(dataflow_output_desc)

# allocate local temporary storage
assert dataflow_output_desc.dtype == dace_utils.as_dace_type(field_dtype)
assert dataflow_output_desc.dtype == dace_utils.as_dace_type(element_type)
field_name, _ = sdfg_builder.add_temp_array(sdfg, field_shape, dataflow_output_desc.dtype)
field_node = state.add_access(field_name)

Expand All @@ -339,7 +344,7 @@ def _create_field_operator_impl(

return FieldopData(
field_node,
ts.FieldType(field_dims, field_dtype),
ts.FieldType(domain_dims, output_edge.result.gt_dtype),
offset=(field_offset if set(field_offset) != {0} else None),
)

Expand All @@ -358,44 +363,70 @@ def _create_scan_field_operator_impl(
Similar to `_create_scan_field_operator_impl()` but for scan field operators.
"""
dataflow_output_desc = output_edge.result.dc_node.desc(sdfg)
assert isinstance(dataflow_output_desc, dace.data.Array)

domain_dims, domain_offset, domain_shape = _get_field_layout(domain)
domain_indices = _get_domain_indices(domain_dims, domain_offset)
domain_subset = dace_subsets.Range.from_indices(domain_indices)

# the vertical dimension should not belong to the field operator domain
# but we need to write it to the output field
scan_dim_index = domain_dims.index(scan_dim)

domain_subset = (
dace_subsets.Range.from_indices(domain_indices[:scan_dim_index])
+ dace_subsets.Range.from_string(f"0:{dataflow_output_desc.shape[0]}")
+ dace_subsets.Range.from_indices(domain_indices[scan_dim_index + 1 :])
)

if isinstance(output_edge.result.gt_dtype, ts.ScalarType):
# the scan field operator produces a 1D vertical field
assert isinstance(dataflow_output_desc, dace.data.Array)
assert isinstance(output_type.dtype, ts.ScalarType)
if output_edge.result.gt_dtype != output_type.dtype:
raise TypeError(
f"Type mismatch, expected {output_type.dtype} got {output_edge.result.gt_dtype}."
)
element_type = output_edge.result.gt_dtype
field_shape, field_offset = (domain_shape, domain_offset)
# the scan field operator computes a column of scalar values
assert len(dataflow_output_desc.shape) == 1
assert output_edge.result.gt_dtype == output_type.dtype
field_dtype = output_edge.result.gt_dtype
field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset)
# the vertical dimension should not belong to the field operator domain
# but we need to write it to the output field
scan_dim_index = domain_dims.index(scan_dim)
field_subset = (
dace_subsets.Range(domain_subset[:scan_dim_index])
+ dace_subsets.Range.from_array(dataflow_output_desc)
+ dace_subsets.Range(domain_subset[scan_dim_index + 1 :])
)
field_subset = domain_subset
else:
raise NotImplementedError("List of values not supported in scan field operators.")
assert isinstance(output_type.dtype, ts.ListType)
assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType)
assert output_edge.result.gt_dtype.offset_type is not None
element_type = output_edge.result.gt_dtype.element_type
if element_type != output_type.dtype.element_type:
raise TypeError(
f"Type mismatch, expected {output_type.dtype.element_type} got {element_type}."
)
# the scan field operator computes a list of scalar values for each column level
assert len(dataflow_output_desc.shape) == 2
# extend the array with the local dimensions added by the field operator (e.g. `neighbors`)
field_shape = [*domain_shape, dataflow_output_desc.shape[1]]
field_offset = [*domain_offset, dataflow_output_desc.offset[1]]
field_subset = domain_subset + dace_subsets.Range.from_string(
f"0:{dataflow_output_desc.shape[1]}"
)

# allocate local temporary storage
assert dataflow_output_desc.dtype == dace_utils.as_dace_type(field_dtype)
assert dataflow_output_desc.dtype == dace_utils.as_dace_type(element_type)
field_name, field_desc = sdfg_builder.add_temp_array(
sdfg, field_shape, dataflow_output_desc.dtype
)
# the inner and outer strides have to match
scan_output_stride = field_desc.strides[scan_dim_index]
dataflow_output_desc.strides = (scan_output_stride,)
field_node = state.add_access(field_name)
# also consider the stride of the local dimension, in case the scan field operator computes a list
local_strides = field_desc.strides[len(domain_dims) :]
assert len(local_strides) == (1 if isinstance(output_edge.result.gt_dtype, ts.ListType) else 0)
new_inner_strides = [scan_output_stride, *local_strides]
dataflow_output_desc.set_shape(dataflow_output_desc.shape, new_inner_strides)

# and here the edge writing the dataflow result data through the map exit node
field_node = state.add_access(field_name)
output_edge.connect(map_exit, field_node, field_subset)

return FieldopData(
field_node,
ts.FieldType(field_dims, field_dtype),
ts.FieldType(domain_dims, output_edge.result.gt_dtype),
offset=(field_offset if set(field_offset) != {0} else None),
)

Expand Down Expand Up @@ -436,9 +467,12 @@ def _create_field_operator(

assert scan_dim is None or scan_dim in domain_dims
if scan_dim and len(domain_dims) == 1:
# We construct the scan field operator only on the horizontal domain.
# If the field operator computes only the scan dimension,
# We construct the scan field operator on the horizontal domain, while the
# vertical dimension (the column axis) is computed by the loop region.
# If the field operator computes only the column axis (a 1d scan field operator),
# there is no horizontal domain, therefore the map scope is not needed.
# This case currently triggers a DaCe issue and produces wrong CUDA code,
# thus the corresponding test is disabled (see pytest marker `uses_scan_1d_field`).
map_entry, map_exit = (None, None)
else:
# create map range corresponding to the field operator domain
Expand Down

0 comments on commit 8fb0c09

Please sign in to comment.