diff --git a/.github/workflows/general-ci.yml b/.github/workflows/general-ci.yml index 8d622f758f..eb14dbe224 100644 --- a/.github/workflows/general-ci.yml +++ b/.github/workflows/general-ci.yml @@ -55,7 +55,7 @@ jobs: else export DACE_optimizer_automatic_simplification=${{ matrix.simplify }} fi - pytest -n auto --cov-report=xml --cov=dace --tb=short -m "not gpu and not verilator and not tensorflow and not mkl and not sve and not papi and not mlir and not lapack and not fpga and not mpi and not rtl_hardware and not scalapack and not datainstrument and not long" + pytest -n auto --cov-report=xml --cov=dace --tb=short -m "not gpu and not verilator and not tensorflow and not mkl and not sve and not papi and not mlir and not lapack and not fpga and not mpi and not rtl_hardware and not scalapack and not datainstrument and not long" tests/numpy ./codecov - name: Test OpenBLAS LAPACK diff --git a/dace/subsets.py b/dace/subsets.py index 0fdc36c22e..9c79e7d7d1 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -80,7 +80,7 @@ def covers(self, other): # Subsets of different dimensionality can never cover each other. if self.dims() != other.dims(): return ValueError( - f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}" + f"A subset of dimensionality {self.dims()} cannot test covering a subset of dimensionality {other.dims()}" ) if not Config.get('optimizer', 'symbolic_positive'): @@ -106,7 +106,7 @@ def covers_precise(self, other): # Subsets of different dimensionality can never cover each other. if self.dims() != other.dims(): return ValueError( - f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}" + f"A subset of dimensionality {self.dims()} cannot test covering a subset of dimensionality {other.dims()}" ) # If self does not cover other with a bounding box union, return false. 
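For context, `covers()` is only defined between subsets of equal dimensionality, which is what the corrected error messages (now calling the existing `dims()` method instead of the nonexistent `dim()`) report. A minimal sketch of that contract using the public `dace.subsets` API; the concrete subset literals are illustrative only:

```python
from dace import subsets

a = subsets.Range.from_string("0:10, 0:10")
b = subsets.Range.from_string("2:5, 3:7")

# Covering is only defined for subsets of equal dimensionality.
assert a.dims() == b.dims() == 2
print(a.covers(b))  # True: `a` bounds `b` in every dimension
```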
diff --git a/dace/transformation/auto/auto_optimize.py b/dace/transformation/auto/auto_optimize.py index 0c74842634..024776f9de 100644 --- a/dace/transformation/auto/auto_optimize.py +++ b/dace/transformation/auto/auto_optimize.py @@ -14,11 +14,12 @@ import warnings # Transformations -from dace.transformation.dataflow import MapCollapse, TrivialMapElimination, MapFusion, ReduceExpansion +from dace.transformation.passes import FullMapFusion +from dace.transformation.dataflow import MapCollapse, TrivialMapElimination, ReduceExpansion from dace.transformation.interstate import LoopToMap, RefineNestedAccess from dace.transformation.subgraph.composite import CompositeFusion from dace.transformation.subgraph import helpers as xfsh -from dace.transformation import helpers as xfh +from dace.transformation import helpers as xfh, pass_pipeline as ppl # Environments from dace.libraries.blas.environments import intel_mkl as mkl, openblas @@ -57,8 +58,13 @@ def greedy_fuse(graph_or_subgraph: GraphViewType, if isinstance(graph_or_subgraph, SDFG): # If we have an SDFG, recurse into graphs graph_or_subgraph.simplify(validate_all=validate_all) - # MapFusion for trivial cases - graph_or_subgraph.apply_transformations_repeated(MapFusion, validate_all=validate_all) + # Apply MapFusion for the more trivial cases + full_map_fusion_pass = FullMapFusion( + strict_dataflow=True, + validate_all=validate_all, + ) + full_map_fusion_pipeline = ppl.Pipeline([full_map_fusion_pass]) + full_map_fusion_pipeline.apply_pass(graph_or_subgraph, {}) # recurse into graphs for graph in graph_or_subgraph.nodes(): @@ -76,7 +82,13 @@ def greedy_fuse(graph_or_subgraph: GraphViewType, sdfg, graph, subgraph = None, None, None if isinstance(graph_or_subgraph, SDFGState): sdfg = graph_or_subgraph.parent - sdfg.apply_transformations_repeated(MapFusion, validate_all=validate_all) + # Apply MapFusion for the more trivial cases + full_map_fusion_pass = FullMapFusion( + strict_dataflow=True, + validate_all=validate_all, + ) + full_map_fusion_pipeline = ppl.Pipeline([full_map_fusion_pass]) + full_map_fusion_pipeline.apply_pass(sdfg, {}) graph = graph_or_subgraph subgraph = SubgraphView(graph, graph.nodes()) else: diff --git a/dace/transformation/dataflow/buffer_tiling.py b/dace/transformation/dataflow/buffer_tiling.py index a418e167d8..af966d8a32 100644 --- a/dace/transformation/dataflow/buffer_tiling.py +++ b/dace/transformation/dataflow/buffer_tiling.py @@ -98,7 +98,13 @@ def apply(self, graph, sdfg): # Fuse maps some_buffer = next(iter(buffers)) # some dummy to pass to MapFusion.apply_to() - MapFusion.apply_to(sdfg, first_map_exit=tile_map1_exit, array=some_buffer, second_map_entry=tile_map2_entry) + MapFusion.apply_to( + sdfg, + first_map_exit=tile_map1_exit, + array=some_buffer, + second_map_entry=tile_map2_entry, + verify=True, + ) # Optimize the simple cases map1_entry.range.ranges = [ diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index a6762d45c4..fa32b7240c 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -1,537 +1,1789 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -""" This module contains classes that implement the map fusion transformation.
-""" - -from copy import deepcopy as dcpy -from dace.sdfg.sdfg import SDFG -from dace.sdfg.state import SDFGState -from dace import data, dtypes, symbolic, subsets -from dace.sdfg import nodes -from dace.memlet import Memlet -from dace.sdfg import replace -from dace.sdfg import utils as sdutil -from dace.transformation import transformation -from typing import List, Union -import networkx as nx +"""Implements the serial map fusing transformation.""" +import copy +from typing import Any, Dict, List, Optional, Set, Tuple, Union, Iterable +import dace +from dace import data, dtypes, properties, subsets, symbolic, transformation +from dace.sdfg import SDFG, SDFGState, graph, nodes, validation +from dace.transformation import helpers + + +@properties.make_properties class MapFusion(transformation.SingleStateTransformation): - """ Implements the MapFusion transformation. - It wil check for all patterns MapExit -> AccessNode -> MapEntry, and - based on the following rules, fuse them and remove the transient in - between. There are several possibilities of what it does to this - transient in between. - - Essentially, if there is some other place in the - sdfg where it is required, or if it is not a transient, then it will - not be removed. In such a case, it will be linked to the MapExit node - of the new fused map. - - Rules for fusing maps: - 0. The map range of the second map should be a permutation of the - first map range. - 1. Each of the access nodes that are adjacent to the first map exit - should have an edge to the second map entry. If it doesn't, then the - second map entry should not be reachable from this access node. - 2. Any node that has a wcr from the first map exit should not be - adjacent to the second map entry. - 3. Access pattern for the access nodes in the second map should be - the same permutation of the map parameters as the map ranges of the - two maps. Alternatively, this access node should not be adjacent to - the first map entry. + """Implements the MapFusion transformation. + + From a high level perspective it will remove the MapExit node of the first and the MapEntry node of + the second Map. It will then rewire and modify the Memlets such that the data flow bypasses the + intermediate node. For this a new intermediate node will be created, which is much smaller because + it has no longer to store the whole output of the first map, but only the data that is produced by + a single iteration of the first map. The transformation will then remove the old intermediate. + Thus by merging the two Maps together the transformation will reduce the memory footprint. It is + important that it is not always possible to fully remove the intermediate node. For example the + data might be used somewhere else. In this case the intermediate will become an output of the Map. + + An example would be the following: + ```python + for i in range(N): + T[i] = foo(A[i]) + for j in range(N): + B[j] = bar(T[i]) + ``` + which would be translated into: + ```python + for i in range(N): + temp: scalar = foo(A[i]) + B[i] = bar(temp) + ``` + + The checks that two Maps can be fused are quite involved, however, they essentially check: + * If the two Maps cover the same iteration space, essentially have the same start, stop and + iteration , see `find_parameter_remapping()`. + * Furthermore, they verify if the new fused Map did not introduce read write conflict, + essentially it tests if the data is pointwise, i.e. what is read is also written, + see `has_read_write_dependency()`. 
+ * Then it will examine the intermediate data. This will essentially test if the data that + is needed by a single iteration of the second Map is produced by a single iteration of + the first Map, see `partition_first_outputs()`. + + By default `strict_dataflow` is enabled. In this mode the transformation is more conservative. + The main difference is that it will not adjust the subsets of the intermediate, i.e. turning + an array with shape `(1, 1, 1, 1)` into a scalar. Furthermore, shared intermediates (see + `partition_first_outputs()`) will only be created if the data is not referred to downstream in + the dataflow. + + In order to determine if an intermediate can be removed or has to be kept, it is in general + necessary to scan the whole SDFG, which is the default behaviour. There are two ways to + speed this up. The first way is to set `assume_always_shared` to `True`. In this case the + transformation will not perform the scan, but assume that the data is shared, i.e. used + somewhere else. This might lead to dead data flow. + The second way is to use the transformation inside a pipeline, which includes the + `FindSingleUseData` analysis pass. If the result of this pass is present then the + transformation will use it instead to determine if an intermediate can be removed. + Note that `assume_always_shared` takes precedence. + For this pattern the `FullMapFusion` pass is provided, which combines the analysis + pass and `MapFusion`. + + :param only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + :param only_toplevel_maps: Only consider Maps that are at the top. + :param strict_dataflow: Which dataflow mode should be used, see above. + :param assume_always_shared: Assume that all intermediates are shared. + + :note: This transformation modifies more nodes than it matches. + :note: If `assume_always_shared` is `True` then the transformation will assume that + all intermediates are shared. This avoids the problems mentioned above with + the cache at the expense of the creation of dead dataflow. """ - first_map_exit = transformation.PatternNode(nodes.ExitNode) - array = transformation.PatternNode(nodes.AccessNode) - second_map_entry = transformation.PatternNode(nodes.EntryNode) - @staticmethod - def annotates_memlets(): - return False + # Pattern Nodes + first_map_exit = transformation.transformation.PatternNode(nodes.MapExit) + array = transformation.transformation.PatternNode(nodes.AccessNode) + second_map_entry = transformation.transformation.PatternNode(nodes.MapEntry) + + + # Settings + only_toplevel_maps = properties.Property( + dtype=bool, + default=False, + desc="Only perform fusing if the Maps are in the top level.", + ) + only_inner_maps = properties.Property( + dtype=bool, + default=False, + desc="Only perform fusing if the Maps are inner Maps, i.e.
do not have top level scope.", + ) + strict_dataflow = properties.Property( + dtype=bool, + default=True, + desc="If `True` then the transformation will ensure a stricter data flow.", + ) + assume_always_shared = properties.Property( + dtype=bool, + default=False, + desc="If `True` then all intermediates will be classified as shared.", + ) + + + def __init__( + self, + only_inner_maps: Optional[bool] = None, + only_toplevel_maps: Optional[bool] = None, + strict_dataflow: Optional[bool] = None, + assume_always_shared: Optional[bool] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if only_toplevel_maps is not None: + self.only_toplevel_maps = only_toplevel_maps + if only_inner_maps is not None: + self.only_inner_maps = only_inner_maps + if strict_dataflow is not None: + self.strict_dataflow = strict_dataflow + if assume_always_shared is not None: + self.assume_always_shared = assume_always_shared + + # See comment in `is_shared_data()` for more information. + self._single_use_data: Optional[Dict[dace.SDFG, Set[str]]] = None + @classmethod - def expressions(cls): - return [sdutil.node_path_graph(cls.first_map_exit, cls.array, cls.second_map_entry)] - - @staticmethod - def find_permutation(first_map: nodes.Map, second_map: nodes.Map) -> Union[List[int], None]: - """ Find permutation between two map ranges. - - :param first_map: First map. - :param second_map: Second map. - :return: None if no such permutation exists, otherwise a list of - indices L such that L[x]'th parameter of second map has the same range as x'th - parameter of the first map. + def expressions(cls) -> Any: + """Get the match expression. + + The transformation matches the exit node of the top Map that is connected to + an access node that again is connected to the entry node of the second Map. + An important note is that the transformation operates not just on the + matched nodes, but more or less on anything that has an incoming connection + from the first Map or an outgoing connection to the second Map entry. + """ + return [dace.sdfg.utils.node_path_graph(cls.first_map_exit, cls.array, cls.second_map_entry)] + + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Tests if the matched Maps can be merged. + + The two Maps are mergeable iff: + * The general requirements hold, see `can_topologically_be_fused()`. + * There are no read write dependencies. + * The decomposition of the first map's outputs exists. """ - result = [] + # To ensure that the `{src,dst}_subset` are properly set, run initialization. + # See [issue 1708](https://github.com/spcl/dace/issues/1703) + for edge in graph.edges(): + edge.data.try_initialize(sdfg, graph, edge) + + first_map_entry: nodes.MapEntry = graph.entry_node(self.first_map_exit) + first_map_exit: nodes.MapExit = self.first_map_exit + second_map_entry: nodes.MapEntry = self.second_map_entry + + # Check the structural properties of the Maps. The function will return + # the `dict` that describes how the parameters must be renamed (for caching) + # or `None` if the maps can not be structurally fused. + param_repl = self.can_topologically_be_fused( + first_map_entry=first_map_entry, + second_map_entry=second_map_entry, + graph=graph, + sdfg=sdfg + ) + if param_repl is None: + return False - if len(first_map.range) != len(second_map.range): - return None + # Test if there are read write dependencies that are caused by the bodies + # of the Maps, such as referring to the same data.
Note that these tests are + # different from the ones performed by `has_read_write_dependency()`, which + # only checks the data dependencies that go through the scope nodes. + if self.has_inner_read_write_dependency( + first_map_entry=first_map_entry, + second_map_entry=second_map_entry, + state=graph, + sdfg=sdfg, + ): + return False - # Match map ranges with reduce ranges - for i, tmap_rng in enumerate(first_map.range): - found = False - for j, rng in enumerate(second_map.range): - if tmap_rng == rng and j not in result: - result.append(j) - found = True - break - if not found: - break + # Test for read write conflicts of the two maps; this only checks + # the data that goes through the scope nodes. `has_inner_read_write_dependency()` + # is used to check if there are internal dependencies. + if self.has_read_write_dependency( + first_map_entry=first_map_entry, + second_map_entry=second_map_entry, + param_repl=param_repl, + state=graph, + sdfg=sdfg, + ): + return False - # Ensure all map ranges matched - if len(result) != len(first_map.range): - return None + # Two maps can be serially fused if the node decomposition exists and + # at least one of the intermediate output sets is not empty. The pure + # output set is irrelevant for serial map fusion. + output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + first_map_exit=first_map_exit, + second_map_entry=second_map_entry, + param_repl=param_repl, + ) + if output_partition is None: + return False + _, exclusive_outputs, shared_outputs = output_partition + if not (exclusive_outputs or shared_outputs): + return False + + return True + + + def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: + """Performs the serial Map fusion. + + The function first renames the parameters of the second Map to match the first, + then computes the map decomposition and handles the three sets. The pure + outputs are handled by `relocate_nodes()` while the two intermediate sets are + handled by `handle_intermediate_set()`. + + :param graph: The SDFG state we are operating on. + :param sdfg: The SDFG we are operating on. + """ + # To ensure that the `{src,dst}_subset` are properly set, run initialization. + # See [issue 1708](https://github.com/spcl/dace/issues/1703) + for edge in graph.edges(): + edge.data.try_initialize(sdfg, graph, edge) + + first_map_exit: nodes.MapExit = self.first_map_exit + second_map_entry: nodes.MapEntry = self.second_map_entry + second_map_exit: nodes.MapExit = graph.exit_node(self.second_map_entry) + first_map_entry: nodes.MapEntry = graph.entry_node(self.first_map_exit) + + # Before we do anything we perform the renaming. + self.rename_map_parameters( + first_map=first_map_exit.map, + second_map=second_map_entry.map, + second_map_entry=second_map_entry, + state=graph, + ) + + # Now compute the partition. Because we have already renamed the parameters + # of the second Map, there is no need to perform any renaming, thus we can + # pass an empty `dict`. + output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + first_map_exit=first_map_exit, + second_map_entry=second_map_entry, + param_repl=dict(), + ) + assert output_partition is not None # Make MyPy happy. + pure_outputs, exclusive_outputs, shared_outputs = output_partition + + # Now perform the actual rewiring; we handle each partition separately.
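+        # The intermediate sets are handled first; the edges that afterwards remain on the first map exit are exactly the pure outputs, which are then relocated to the second map's exit.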
+ if len(exclusive_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=exclusive_outputs, + state=graph, + sdfg=sdfg, + first_map_exit=first_map_exit, + second_map_entry=second_map_entry, + second_map_exit=second_map_exit, + is_exclusive_set=True, + ) + if len(shared_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=shared_outputs, + state=graph, + sdfg=sdfg, + first_map_exit=first_map_exit, + second_map_entry=second_map_entry, + second_map_exit=second_map_exit, + is_exclusive_set=False, + ) + assert pure_outputs == set(graph.out_edges(first_map_exit)) + if len(pure_outputs) != 0: + self.relocate_nodes( + from_node=first_map_exit, + to_node=second_map_exit, + state=graph, + sdfg=sdfg, + ) + + # Now move the inputs of the second map that have no connection to the first + # map over to the first map's entry. This is needed because we will later + # delete the second map's entry (the first map's exit was essentially + # handled above). + self.relocate_nodes( + from_node=second_map_entry, + to_node=first_map_entry, + state=graph, + sdfg=sdfg, + ) + + for node_to_remove in [first_map_exit, second_map_entry]: + assert graph.degree(node_to_remove) == 0 + graph.remove_node(node_to_remove) + + # Now make the second map's exit the exit node of the fused (first) Map. + second_map_exit.map = first_map_entry.map + + + def partition_first_outputs( + self, + state: SDFGState, + sdfg: SDFG, + first_map_exit: nodes.MapExit, + second_map_entry: nodes.MapEntry, + param_repl: Dict[str, str], + ) -> Union[ + Tuple[ + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], + ], + None, + ]: + """Partition the output edges of `first_map_exit` for serial map fusion. + + The output edges of the first map are partitioned into three distinct sets, + defined as follows: + * Pure Output Set `\mathbb{P}`: + These edges exit the first map and do not enter the second map. These + outputs will simply be moved to the output of the second map. + * Exclusive Intermediate Set `\mathbb{E}`: + Edges in this set leave the first map exit and enter an access node, from + where a Memlet then leads immediately to the second map. The memory + referenced by this access node is not used anywhere else, thus it can + be removed. + * Shared Intermediate Set `\mathbb{S}`: + These edges are very similar to the ones in `\mathbb{E}` except that they + are used somewhere else, thus they can not be removed and have to be + recreated as output of the second map. + + If strict data flow mode is enabled, the function is rather strict about + whether an output can be added to either intermediate set and might fail to + compute the partition even if one exists. + + :return: If such a decomposition exists the function will return the three sets + mentioned above in the same order. In case the decomposition does not exist, + i.e. the maps can not be fused, the function returns `None`. + + :param state: The state in which the two maps are located. + :param sdfg: The full SDFG in which we operate. + :param first_map_exit: The exit node of the first map. + :param second_map_entry: The entry node of the second map. + :param param_repl: Use this map to rename the parameters of the second Map, such + that they match the ones of the first map. + """ + # The three output sets.
+ pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + exclusive_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + shared_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + + # Set of intermediate nodes that we have already processed. + processed_inter_nodes: Set[nodes.Node] = set() + + # Now scan all output edges of the first exit and classify them + for out_edge in state.out_edges(first_map_exit): + intermediate_node: nodes.Node = out_edge.dst + + # We already processed the node; this should indicate that we should + # run simplify again, or we should start implementing this case. + # TODO(phimuell): Handle this case, already partially handled here. + if intermediate_node in processed_inter_nodes: + return None + processed_inter_nodes.add(intermediate_node) + + # The intermediate node can only have an in-degree of one. It might be + # possible to handle multiple incoming edges if they all come from the top + # map. However, the resulting SDFG might be invalid. + # NOTE: Allow this to happen (under certain cases) if the only producer + # is the top map. + if state.in_degree(intermediate_node) != 1: + return None + + # If the second map is not reachable from the intermediate node, then + # the output is pure and we can end here. + if not self.is_node_reachable_from( + graph=state, + begin=intermediate_node, + end=second_map_entry, + ): + pure_outputs.add(out_edge) + continue + + # The following tests are _after_ we have determined if we have a pure + # output node, because this allows us to handle more exotic pure node + # cases, as handling them is essentially rerouting an edge, whereas + # handling intermediate nodes is much more complicated. + + # Empty Memlets are only allowed if they are in `\mathbb{P}`, which + # is also the only place they really make sense (for a map exit). + # Thus, if we now find an empty Memlet, we reject it. + if out_edge.data.is_empty(): + return None + + # For us an intermediate node must always be an access node, because + # we do not know how to handle anything else. It is important that + # we do not test for non transient data here, because it can be + # handled as a shared intermediate. + if not isinstance(intermediate_node, nodes.AccessNode): + return None + intermediate_desc: dace.data.Data = intermediate_node.desc(sdfg) + if self.is_view(intermediate_desc, sdfg): + return None + + # It can happen that multiple edges converge at the `IN_` connector + # of the first map exit, but there is only one edge leaving the exit. + # It is complicated to handle this, so for now we ignore it. + # TODO(phimuell): Handle this case properly. + # To handle this we need to associate a consumer edge (the outgoing edges + # of the second map) with exactly one producer. + producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list( + state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:])) + if len(producer_edges) > 1: + return None + + # Now check the constraints we have on the producers. + # - The source of the producer can not be a view (we do not handle this) + # - The edge shall also not be a reduction edge. + # - Defined location to where they write. + # - No dynamic Memlets. + # Furthermore, we will also extract the subsets, i.e. the locations they + # modify inside the intermediate array. + # Since we do not allow for WCR, we also check below that the producer subsets do not intersect.
+ producer_subsets: List[subsets.Subset] = [] + for producer_edge in producer_edges: + if isinstance(producer_edge.src, nodes.AccessNode) and self.is_view(producer_edge.src, sdfg): + return None + if producer_edge.data.dynamic: + # TODO(phimuell): Find out if this restriction could be lifted, but it is unlikely. + return None + if producer_edge.data.wcr is not None: + return None + if producer_edge.data.dst_subset is None: + return None + producer_subsets.append(producer_edge.data.dst_subset) + + # Check that the producers do not intersect + if len(producer_subsets) == 1: + pass + elif len(producer_subsets) == 2: + if producer_subsets[0].intersects(producer_subsets[1]): + return None + else: + for i, psbs1 in enumerate(producer_subsets): + for j, psbs2 in enumerate(producer_subsets): + if i == j: + continue + if psbs1.intersects(psbs2): + return None + + # Now we determine the consumers of the intermediate node. For this we use + # the edges that leave the second map entry. It is not necessary to find + # the actual consumer nodes, as they might depend on symbols of nested Maps. + # For the covering test we only need their subsets, but we will perform + # some scan and filtering on them. + found_second_map = False + consumer_subsets: List[subsets.Subset] = [] + for intermediate_node_out_edge in state.out_edges(intermediate_node): + + # If the second map entry is not immediately reachable from the intermediate + # node, then ensure that there is no path that goes to it. + if intermediate_node_out_edge.dst is not second_map_entry: + if self.is_node_reachable_from(graph=state, begin=intermediate_node_out_edge.dst, end=second_map_entry): + return None + continue + + # Ensure that the second map is found exactly once. + # TODO(phimuell): Lift this restriction. + if found_second_map: + return None + found_second_map = True + + # The output of the top map can not define a dynamic map range in the + # second map. + if not intermediate_node_out_edge.dst_conn.startswith("IN_"): + return None + + # Now we look at all edges that leave the second map entry, i.e. the + # edges that feed the consumers and define what is read inside the map. + # We do not check them, but collect them and inspect them. + # NOTE1: The subset still uses the old iteration variables. + # NOTE2: In the case of consumer Memlets we explicitly allow dynamic Memlets. + # This is different compared to the producer Memlets. The reason is + # that in a consumer the data is conditionally read, so the data + # has to exist anyway. + for inner_consumer_edge in state.out_edges_by_connector( + second_map_entry, "OUT_" + intermediate_node_out_edge.dst_conn[3:]): + if inner_consumer_edge.data.src_subset is None: + return None + consumer_subsets.append(inner_consumer_edge.data.src_subset) + assert found_second_map, f"Found '{intermediate_node}' which looked like a pure node, but is not one." + assert len(consumer_subsets) != 0 + + # The consumer still uses the original symbols of the second map, so we must rename them. + if param_repl: + consumer_subsets = copy.deepcopy(consumer_subsets) + for consumer_subset in consumer_subsets: + symbolic.safe_replace(mapping=param_repl, replace_callback=consumer_subset.replace) + + # Now we check if a single iteration of the first (top) map + # can satisfy all data requirements of the second (bottom) map. + # For this we check if the producers cover the consumers. A consumer must + # be covered by exactly one producer.
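+            # For example, a producer writing `T[i, 2:6]` covers a consumer reading `T[i, 3]`, while a consumer reading `T[i, 0:8]` would not be covered and the partition would fail.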
+ for consumer_subset in consumer_subsets: + nb_coverings = sum(producer_subset.covers(consumer_subset) for producer_subset in producer_subsets) + if nb_coverings != 1: + return None + + # After we have ensured coverage, we have to decide if the intermediate + # node can be removed (`\mathbb{E}`) or has to be restored (`\mathbb{S}`). + # Note that "restored" here means that it is reconstructed by a new + # output of the second map. + if self.is_shared_data(data=intermediate_node, state=state, sdfg=sdfg): + # The intermediate data is used somewhere else, either in this or another state. + # NOTE: If the intermediate is shared, then we will turn it into a + # sink node attached to the combined map exit. Technically this + # should be enough, even if the same data appears again in the + # dataflow downstream. However, some DaCe transformations, + # I am looking at you `auto_optimizer()`, do not like that. Thus, + # if the intermediate is used further down in the same dataflow, + # then we consider that the maps can not be fused. But we only + # do this in the strict data flow mode. + if self.strict_dataflow: + if self._is_data_accessed_downstream( + data=intermediate_node.data, + graph=state, + begin=intermediate_node, # is ignored itself. + ): + return None + shared_outputs.add(out_edge) + else: + # The intermediate can be removed, as it is not used anywhere else. + exclusive_outputs.add(out_edge) + + assert len(processed_inter_nodes) == sum(len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs]) + return (pure_outputs, exclusive_outputs, shared_outputs) + + + def relocate_nodes( + self, + from_node: Union[nodes.MapExit, nodes.MapEntry], + to_node: Union[nodes.MapExit, nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, + ) -> None: + """Move the connectors and edges from `from_node` to `to_node`. + + This function will only rewire the edges; it does not remove the nodes + themselves. Furthermore, this function should be called twice per Map, + once for the entry and once for the exit. + While it does not remove the nodes themselves, it guarantees that + `from_node` has degree zero. + The function assumes that the parameter renaming was already done. + + :param from_node: Node from which the edges should be removed. + :param to_node: Node to which the edges should reconnect. + :param state: The state in which the operation happens. + :param sdfg: The SDFG that is modified. + """ + + # Now we relocate empty Memlets from the `from_node` to the `to_node` + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.out_edges(from_node))): + helpers.redirect_edge(state, empty_edge, new_src=to_node) + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.in_edges(from_node))): + helpers.redirect_edge(state, empty_edge, new_dst=to_node) + + # We now ensure that there is only one empty Memlet from the `to_node` to any other node. + # Although it is allowed, we try to prevent it. + empty_targets: Set[nodes.Node] = set() + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.all_edges(to_node))): + if empty_edge.dst in empty_targets: + state.remove_edge(empty_edge) + empty_targets.add(empty_edge.dst) + + # We now determine which edges we have to migrate; for this we look at + # the incoming edges, because this also allows us to detect dynamic map ranges. + # TODO(phimuell): If there is already a connection to the node, reuse this.
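+        # Dynamic map range edges enter the scope node at a plain (non-`IN_`) connector and define a per-iteration symbol; they are treated separately below.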
+ for edge_to_move in list(state.in_edges(from_node)): + assert isinstance(edge_to_move.dst_conn, str) + + if not edge_to_move.dst_conn.startswith("IN_"): + # Dynamic Map Range + # The connector name simply defines a variable name that is used + # inside the Map scope. We handle it directly. + dmr_symbol = edge_to_move.dst_conn + + # TODO(phimuell): Check if the symbol is really unused in the target scope. + if dmr_symbol in to_node.in_connectors: + raise NotImplementedError(f"Tried to move the dynamic map range '{dmr_symbol}' from '{from_node}'" + f" to '{to_node}', but the symbol is already known there and the" + " renaming is not implemented.") + if not to_node.add_in_connector(dmr_symbol, force=False): + raise RuntimeError( # Might fail because of out connectors. + f"Failed to add the dynamic map range symbol '{dmr_symbol}' to '{to_node}'.") + helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) + from_node.remove_in_connector(dmr_symbol) - return result - - def can_be_applied(self, graph, expr_index, sdfg: SDFG, permissive=False): - first_map_exit = self.first_map_exit - first_map_entry = graph.entry_node(first_map_exit) - second_map_entry = self.second_map_entry - second_map_exit = graph.exit_node(second_map_entry) - - for _in_e in graph.in_edges(first_map_exit): - if _in_e.data.wcr is not None: - for _out_e in graph.out_edges(second_map_entry): - if _out_e.data.data == _in_e.data.data: - # wcr is on a node that is used in the second map, quit - return False - # Check whether there is a pattern map -> access -> map. - intermediate_nodes = set() - intermediate_data = set() - for _, _, dst, _, _ in graph.out_edges(first_map_exit): - if isinstance(dst, nodes.AccessNode): - intermediate_nodes.add(dst) - intermediate_data.add(dst.data) - - # If array is used anywhere else in this state. - num_occurrences = len([n for n in sdfg.data_nodes() if n.data == dst.data]) - if num_occurrences > 1: - return False else: + # We have a Passthrough connection, i.e. there exists a matching `OUT_`. + old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix + new_conn = to_node.next_connector(old_conn) + + to_node.add_in_connector("IN_" + new_conn) + for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): + helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) + to_node.add_out_connector("OUT_" + new_conn) + for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): + helpers.redirect_edge(state, e, new_src=to_node, new_src_conn="OUT_" + new_conn) + from_node.remove_in_connector("IN_" + old_conn) + from_node.remove_out_connector("OUT_" + old_conn) + + # Check if we succeeded.
+ if state.out_degree(from_node) != 0: + raise validation.InvalidSDFGError( + f"Failed to relocate the outgoing edges from `{from_node}`, there are still `{state.out_edges(from_node)}`", + sdfg, + sdfg.node_id(state), + ) + if state.in_degree(from_node) != 0: + raise validation.InvalidSDFGError( + f"Failed to relocate the incoming edges from `{from_node}`, there are still `{state.in_edges(from_node)}`", + sdfg, + sdfg.node_id(state), + ) + assert len(from_node.in_connectors) == 0 + assert len(from_node.out_connectors) == 0 + + + def handle_intermediate_set( + self, + intermediate_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]], + state: SDFGState, + sdfg: SDFG, + first_map_exit: nodes.MapExit, + second_map_entry: nodes.MapEntry, + second_map_exit: nodes.MapExit, + is_exclusive_set: bool, + ) -> None: + """This function handles the intermediate sets. + + The function is able to handle both the shared and exclusive intermediate + output set, see `partition_first_outputs()`. The main difference is that + in exclusive mode the intermediate nodes will be fully removed from + the SDFG, while in shared mode the intermediate node will be preserved. + The function assumes that the parameter renaming was already done. + + :param intermediate_outputs: The set of outputs that should be processed. + :param state: The state in which the map is processed. + :param sdfg: The SDFG that should be optimized. + :param first_map_exit: The exit of the first/top map. + :param second_map_entry: The entry of the second map. + :param second_map_exit: The exit of the second map. + :param is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. + + :note: Before the transformation the `state` does not have to be valid and + after this function has run the state is (most likely) invalid. + """ + + map_params = first_map_exit.map.params.copy() + + # Now we will iterate over all intermediate edges and process them. + # If not stated otherwise the comments assume that we run in exclusive mode. + for out_edge in intermediate_outputs: + # This is the intermediate node that we want to get rid of. + # In shared mode we want to recreate it after the second map. + inter_node: nodes.AccessNode = out_edge.dst + inter_name = inter_node.data + inter_desc = inter_node.desc(sdfg) + inter_shape = inter_desc.shape + + # Now we will determine the shape of the new intermediate. The size of + # this temporary is given by the Memlet that goes into the first map exit. + pre_exit_edges = list(state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:])) + if len(pre_exit_edges) != 1: + raise NotImplementedError() + pre_exit_edge = pre_exit_edges[0] + + (new_inter_shape_raw, new_inter_shape, squeezed_dims) = self.compute_reduced_intermediate( + producer_subset=pre_exit_edge.data.dst_subset, + inter_desc=inter_desc, + ) + + # This is the name of the new "intermediate" node that we will create. + # It will only have the shape `new_inter_shape`, which is basically its + # output within one Map iteration. + # NOTE: The insertion process might generate a new name. + new_inter_name: str = f"__s{self.state_id}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" + + # Now generate the intermediate data container.
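+            # A zero-dimensional reduced shape means the first map produces a single element per iteration, so a scalar container suffices; otherwise a transient array of the reduced shape is allocated.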
+ if len(new_inter_shape) == 0: + assert pre_exit_edge.data.subset.num_elements() == 1 + is_scalar = True + new_inter_name, new_inter_desc = sdfg.add_scalar( + new_inter_name, + dtype=inter_desc.dtype, + transient=True, + find_new_name=True, + ) + + else: + assert (pre_exit_edge.data.subset.num_elements() > 1) or all(x == 1 for x in new_inter_shape) + is_scalar = False + new_inter_name, new_inter_desc = sdfg.add_transient( + new_inter_name, + shape=new_inter_shape, + dtype=inter_desc.dtype, + find_new_name=True, + ) + new_inter_node: nodes.AccessNode = state.add_access(new_inter_name) + + # Get the subset that defines into which part of the old intermediate + # the old output edge wrote. We need that to adjust the producer + # Memlets, since they now write into the new (smaller) intermediate. + producer_offset = self.compute_offset_subset( + original_subset=pre_exit_edge.data.dst_subset, + intermediate_desc=inter_desc, + map_params=map_params, + producer_offset=None, + ) + + # Memlets carry a lot of additional information; to ensure that we get + # all of it, we have to do it this way. The main reason for this is + # to handle the case where the Memlet reverses direction, i.e. `data` + # refers to the other end of the connection than before. + assert pre_exit_edge.data.dst_subset is not None + new_pre_exit_memlet_src_subset = copy.deepcopy(pre_exit_edge.data.src_subset) + new_pre_exit_memlet_dst_subset = subsets.Range.from_array(new_inter_desc) + + new_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) + new_pre_exit_memlet.data = new_inter_name + + new_pre_exit_edge = state.add_edge( + pre_exit_edge.src, + pre_exit_edge.src_conn, + new_inter_node, + None, + new_pre_exit_memlet, + ) + + # We can update `{src, dst}_subset` only after we have inserted the + # edge; this is because the direction of the Memlet might change. + new_pre_exit_edge.data.src_subset = new_pre_exit_memlet_src_subset + new_pre_exit_edge.data.dst_subset = new_pre_exit_memlet_dst_subset + + # We now handle the MemletTree defined by this edge. + # The newly created edge only handles the last collection step. + for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children(include_self=False): + producer_edge = producer_tree.edge + + # In order to preserve the intrinsic direction of Memlets we only have to change + # the `.data` attribute of the producer Memlet if it refers to the old intermediate. + # If it refers to something different we keep it. Note that this case can only + # occur if the producer is an AccessNode. + if producer_edge.data.data == inter_name: + producer_edge.data.data = new_inter_name + + # Regardless of the intrinsic direction of the Memlet, the subset we care about + # is always `dst_subset`. + if is_scalar: + producer_edge.data.dst_subset = "0" + elif producer_edge.data.dst_subset is not None: + # Since we now write into a smaller memory patch, we must + # compensate for that. We do this by subtracting where the write + # originally had begun. + producer_edge.data.dst_subset.offset(producer_offset, negative=True) + producer_edge.data.dst_subset.pop(squeezed_dims) + + # Now, after we have handled the input of the new intermediate node, + # we must handle its output. For this we have to "inject" the newly + # created intermediate into the second map. We do this by finding + # the input connectors on the map entry, such that we know where we + # have to reroute inside the Map. + # NOTE: Assumes that the map (if connected) is the direct neighbour.
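+            # Collect the `IN_` connectors through which the second map entry consumes the old intermediate node.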
+ conn_names: Set[str] = set() + for inter_node_out_edge in state.out_edges(inter_node): + if inter_node_out_edge.dst == second_map_entry: + assert inter_node_out_edge.dst_conn.startswith("IN_") + conn_names.add(inter_node_out_edge.dst_conn) + else: + # If we found another target than the second map entry from the + # intermediate node it means that the node _must_ survive, + # i.e. we are not in exclusive mode. + assert not is_exclusive_set + + # Now we will reroute the connections inside the second map, i.e. + # instead of consuming the old intermediate node, they will now + # consume the new intermediate node. + for in_conn_name in conn_names: + out_conn_name = "OUT_" + in_conn_name[3:] + + for inner_edge in state.out_edges_by_connector(second_map_entry, out_conn_name): + # As on the producer side, we now read from a smaller array, + # so we must offset the subsets; we use the original edge for this. + assert inner_edge.data.src_subset is not None + consumer_offset = self.compute_offset_subset( + original_subset=inner_edge.data.src_subset, + intermediate_desc=inter_desc, + map_params=map_params, + producer_offset=producer_offset, + ) + + # Now create the memlet for the new consumer. To make sure that we get all attributes + # of the Memlet we make a deep copy of it. There is a tricky part here: we have to + # access `src_subset`; however, it is only correctly set once the Memlet is put inside the + # SDFG. Furthermore, we have to make sure that the Memlet does not change its direction, + # i.e. that the association of `subset` and `other_subset` does not change. For this + # reason we only modify the `.data` attribute of the Memlet if its name refers to the old + # intermediate. Furthermore, to play it safe, we only access `src_subset` + # after we have inserted the Memlet into the SDFG. + new_inner_memlet = copy.deepcopy(inner_edge.data) + if inner_edge.data.data == inter_name: + new_inner_memlet.data = new_inter_name + + # Now we replace the edge in the SDFG. + state.remove_edge(inner_edge) + new_inner_edge = state.add_edge( + new_inter_node, + None, + inner_edge.dst, + inner_edge.dst_conn, + new_inner_memlet, + ) + + # Now we modify the Memlet; we do it after the insertion to make + # sure that the Memlet is properly initialized. + if is_scalar: + new_inner_memlet.subset = "0" + elif new_inner_memlet.src_subset is not None: + # TODO(phimuell): Figuring out if `src_subset` is None is an error. + new_inner_memlet.src_subset.offset(consumer_offset, negative=True) + new_inner_memlet.src_subset.pop(squeezed_dims) + + # Now we have to make sure that all consumers are properly updated. + for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(include_self=False): + consumer_edge = consumer_tree.edge + + # We only modify the data if the Memlet refers to the old intermediate data. + # We can not do this unconditionally, because it might change the intrinsic + # direction of a Memlet and then `src_subset` would be wrong at the next + # `try_initialize`. Note that this case only occurs if the destination is an AccessNode. + if consumer_edge.data.data == inter_name: + consumer_edge.data.data = new_inter_name + + # Now we have to adapt the subsets. + if is_scalar: + consumer_edge.data.src_subset = "0" + elif consumer_edge.data.src_subset is not None: + # TODO(phimuell): Figuring out if `src_subset` is None is an error.
+ consumer_edge.data.src_subset.offset(consumer_offset, negative=True) + consumer_edge.data.src_subset.pop(squeezed_dims) + + # The edge that leaves the second map entry was already deleted. We now delete + # the edges that connected the intermediate node with the second map entry. + for edge in list(state.in_edges_by_connector(second_map_entry, in_conn_name)): + assert edge.src == inter_node + state.remove_edge(edge) + second_map_entry.remove_in_connector(in_conn_name) + second_map_entry.remove_out_connector(out_conn_name) + + if is_exclusive_set: + # In exclusive mode the old intermediate node is no longer needed. + # This will also remove `out_edge` from the SDFG. + assert state.degree(inter_node) == 1 + state.remove_edge_and_connectors(out_edge) + state.remove_node(inter_node) + + state.remove_edge(pre_exit_edge) + first_map_exit.remove_in_connector(pre_exit_edge.dst_conn) + first_map_exit.remove_out_connector(out_edge.src_conn) + del sdfg.arrays[inter_name] + + else: + # TODO(phimuell): Lift this restriction + assert pre_exit_edge.data.data == inter_name + + # This is the shared mode, so we have to recreate the intermediate + # node, but this time it is at the exit of the second map. + state.remove_edge(pre_exit_edge) + first_map_exit.remove_in_connector(pre_exit_edge.dst_conn) + + # This is the Memlet that goes from the map internal intermediate + # temporary node to the Map output. This will essentially restore + # or preserve the output for the intermediate node. It is important + # that we use the same data that `pre_exit_edge` used. + final_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) + final_pre_exit_memlet.other_subset = subsets.Range.from_array(new_inter_desc) + + new_pre_exit_conn = second_map_exit.next_connector() + state.add_edge( + new_inter_node, + None, + second_map_exit, + "IN_" + new_pre_exit_conn, + final_pre_exit_memlet, + ) + state.add_edge( + second_map_exit, + "OUT_" + new_pre_exit_conn, + inter_node, + out_edge.dst_conn, + copy.deepcopy(out_edge.data), + ) + second_map_exit.add_in_connector("IN_" + new_pre_exit_conn) + second_map_exit.add_out_connector("OUT_" + new_pre_exit_conn) + + first_map_exit.remove_out_connector(out_edge.src_conn) + state.remove_edge(out_edge) + + + def compute_reduced_intermediate( + self, + producer_subset: subsets.Range, + inter_desc: dace.data.Data, + ) -> Tuple[Tuple[int, ...], Tuple[int, ...], List[int]]: + """Compute the size of the new (reduced) intermediate. + + `MapFusion` does not only fuse maps but, depending on the situation, also + eliminates intermediate arrays between the two maps. To transmit data between + the two maps a new, but much smaller, intermediate is needed. + + :return: The function returns a tuple with three values with the following meaning: + * The raw shape of the reduced intermediate. + * The cleared shape of the reduced intermediate, essentially the raw shape + with all shape 1 dimensions removed. + * Which dimensions of the raw shape have been removed to get the cleared shape. + + :param producer_subset: The subset that was used to write into the intermediate. + :param inter_desc: The data descriptor for the intermediate. + """ + assert producer_subset is not None + + # Over approximation will leave us with some unneeded size one dimensions. + # If they are removed, some DaCe transformations (especially auto optimization) + # will have problems.
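+        # For example, a producer subset `[i, 2:6]` over-approximates to the raw shape `(1, 4)`; in non-strict mode the leading size-one dimension is then squeezed away below.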
+ new_inter_shape_raw = symbolic.overapproximate(producer_subset.size()) + inter_shape = inter_desc.shape + if not self.strict_dataflow: + squeezed_dims: List[int] = [] # These are the dimensions we removed. + new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. + for dim, (proposed_dim_size, full_dim_size) in enumerate(zip(new_inter_shape_raw, inter_shape)): + if full_dim_size == 1: # Must be kept! + new_inter_shape.append(proposed_dim_size) + elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. + squeezed_dims.append(dim) + else: + new_inter_shape.append(proposed_dim_size) + else: + squeezed_dims = [] + new_inter_shape = list(new_inter_shape_raw) + + return (tuple(new_inter_shape_raw), tuple(new_inter_shape), squeezed_dims) + + + def compute_offset_subset( + self, + original_subset: subsets.Range, + intermediate_desc: data.Data, + map_params: List[str], + producer_offset: Union[subsets.Range, None], + ) -> subsets.Range: + """Computes the subset offset needed to correct reads and writes of the intermediate. + + This is the value that must be subtracted from the memlets to adjust, i.e. + (`memlet_to_adjust(correction, negative=True)`). If `producer_offset` is + `None` then the function computes the correction that should be applied to + the producer memlets, i.e. the memlets of the tree converging at + `intermediate_node`. If `producer_offset` is given, it should be the output + of the previous call to this function, with `producer_offset=None`. In this + case the function computes the correction for the consumer side, i.e. the + memlet tree that originates at `intermediate_desc`. + + :param original_subset: The original subset that was used to write into the + intermediate, must be renamed to the final map parameter. + :param intermediate_desc: The original intermediate data descriptor. + :param map_params: The parameters of the final map. + :param producer_offset: The correction that was applied to the producer side. + """ + assert not isinstance(intermediate_desc, data.View) + final_offset: subsets.Range = None + if isinstance(intermediate_desc, data.Scalar): + # If the intermediate was a scalar, then it will remain a scalar. + # Thus there is no correction that we must apply. + return subsets.Range.from_string("0") + + elif isinstance(intermediate_desc, data.Array): + basic_offsets = original_subset.min_element() + offset_list = [] + for d in range(original_subset.dims()): + d_range = subsets.Range([original_subset[d]]) + if d_range.free_symbols.intersection(map_params): + offset_list.append(d_range[0]) + else: + offset_list.append((basic_offsets[d], basic_offsets[d], 1)) + final_offset = subsets.Range(offset_list) + + else: + raise TypeError(f"Does not know how to compute the subset offset for '{type(intermediate_desc).__name__}'.") + + if producer_offset is not None: + # Here we are correcting some parts that the over approximation (which can + # partially under approximate) might screw up. Consider two maps, the first + # map only writes the subset `[:, 2:6]`, thus the new intermediate will + # have shape `(1, 4)`. Now also imagine that the second map only reads + # the elements `[:, 3]`. From this we see that we can only correct the + # consumer side if we also take the producer side into consideration! + # See also the `transformations/mapfusion_test.py::test_offset_correction_*` + # tests for more.
+ final_offset.offset( + final_offset.offset_new( + producer_offset, + negative=True, + ), + negative=True, + ) + return final_offset + + + def can_topologically_be_fused( + self, + first_map_entry: nodes.MapEntry, + second_map_entry: nodes.MapEntry, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + permissive: bool = False, + ) -> Optional[Dict[str, str]]: + """Performs basic checks if the maps can be fused. + + This function only checks constraints that are common between the serial and + parallel map fusion process, which include: + * The scope of the maps. + * The scheduling of the maps. + * The map parameters. + + :return: If the maps can not be topologically fused the function returns `None`. + If they can be fused the function returns a `dict` that describes the parameter + replacement, see `find_parameter_remapping()` for more. + + :param first_map_entry: The entry of the first (in the serial case the top) map. + :param second_map_entry: The entry of the second (in the serial case the bottom) map. + :param graph: The SDFGState in which the maps are located. + :param sdfg: The SDFG itself. + :param permissive: Currently unused. + """ + if self.only_inner_maps and self.only_toplevel_maps: + raise ValueError("Only one of `only_inner_maps` and `only_toplevel_maps` is allowed per MapFusion instance.") + + # Ensure that both have the same schedule + if first_map_entry.map.schedule != second_map_entry.map.schedule: + return None + + # Fusing is only possible if the two entries are in the same scope. + scope = graph.scope_dict() + if scope[first_map_entry] != scope[second_map_entry]: + return None + elif self.only_inner_maps: + if scope[first_map_entry] is None: + return None + elif self.only_toplevel_maps: + if scope[first_map_entry] is not None: + return None + + # We will now check if we can rename the Map parameters of the second Map such that they + # match the ones of the first Map. + param_repl = self.find_parameter_remapping(first_map=first_map_entry.map, second_map=second_map_entry.map) + return param_repl + + + def has_inner_read_write_dependency( + self, + first_map_entry: nodes.MapEntry, + second_map_entry: nodes.MapEntry, + state: SDFGState, + sdfg: SDFG, + ) -> bool: + """This function tests if there are dependencies inside the Maps. + + The function will scan and analyze the bodies of the two Maps and look for + inconsistencies. To detect them the function will scan the body of the maps + and examine all AccessNodes, applying the following rules: + * If an AccessNode refers to a View, it is ignored, because the source is + either on the outside, in which case `has_read_write_dependency()` + takes care of it, or the data source is inside the Map body itself. + * An inconsistency is detected if both bodies contain AccessNodes + that refer to the same data. + * An inconsistency is detected if there exists an AccessNode that refers + to non transient data. This is an implementation detail and could be + lifted. + + Note that some of the restrictions of this function could be relaxed by + performing more analysis. + + :return: The function returns `True` if an inconsistency has been found. + + :param first_map_entry: The entry node of the first map. + :param second_map_entry: The entry node of the second map. + :param state: The state on which we operate. + :param sdfg: The SDFG on which we operate.
+ """ + first_map_body = state.scope_subgraph(first_map_entry, False, False) + second_map_body = state.scope_subgraph(second_map_entry, False, False) + + # Find the data that is internally referenced. Because of the first rule above, + # we filter all views above. + first_map_body_data, second_map_body_data = [ + { + dnode.data + for dnode in map_body.nodes() + if isinstance(dnode, nodes.AccessNode) and not self.is_view(dnode, sdfg) + } + for map_body in [first_map_body, second_map_body] + ] + + # If there is data that is referenced in both, then we consider this as an error + # this is the second rule above. + if not first_map_body_data.isdisjoint(second_map_body_data): + return True + + # We consider it as a problem if any map refers to non-transient data. + # This is an implementation detail and could be dropped if we do further + # analysis. + if any( + not sdfg.arrays[data].transient + for data in first_map_body_data.union(second_map_body_data) + ): + return True + + return False + + + def has_read_write_dependency( + self, + first_map_entry: nodes.MapEntry, + second_map_entry: nodes.MapEntry, + param_repl: Dict[str, str], + state: SDFGState, + sdfg: SDFG, + ) -> bool: + """Test if there is a read write dependency between the two maps to be fused. + + The function checks three different things. + * The function will make sure that there is no read write dependency between + the input and output of the fused maps. For that it will inspect the + respective subsets of the inputs of the MapEntry of the first and the + outputs of the MapExit node of the second map. + * The second part partially checks the intermediate nodes, it mostly ensures + that there are not views and that they are not used as output of the + combined map. Note that it is allowed that an intermediate node is also + an input to the first map. + * In case an intermediate node, is also used as input node of the first map, + it is forbidden that the data is used as output of the second map, the + function will do additional checks. This is needed as the partition function + only checks the data consumption of the second map can be satisfied by the + data production of the first map, it ignores any potential reads made by + the first map's MapEntry. + + :return: `True` if there is a conflict between the maps that can not be handled. + If there is no conflict or if the conflict can be handled `False` is returned. + + :param first_map_entry: The entry node of the first map. + :param second_map_entry: The entry node of the second map. + :param param_repl: Dict that describes how to rename the parameters of the second Map. + :param state: The state on which we operate. + :param sdfg: The SDFG on which we operate. + """ + first_map_exit: nodes.MapExit = state.exit_node(first_map_entry) + second_map_exit: nodes.MapExit = state.exit_node(second_map_entry) + + # Get the read and write sets of the different maps, note that Views + # are not resolved yet. + access_sets: List[Dict[str, nodes.AccessNode]] = [] + for scope_node in [first_map_entry, first_map_exit, second_map_entry, second_map_exit]: + access_set: Set[nodes.AccessNode] = self.get_access_set(scope_node, state) + access_sets.append({node.data: node for node in access_set}) + # If two different access nodes of the same scoping node refers to the + # same data, then we consider this as a dependency we can not handle. + # It is only a problem for the intermediate nodes and might be possible + # to handle, but doing so is hard, so we just forbid it. 
+ if len(access_set) != len(access_sets[-1]): + return True + read_map_1, write_map_1, read_map_2, write_map_2 = access_sets + + # It might be possible that there are views, so we have to resolve them. + # We also already get the name of the data container. + # Note that `len(real_read_map_1) <= len(read_map_1)` holds because of Views. + resolved_sets: List[Set[str]] = [] + for unresolved_set in [read_map_1, write_map_1, read_map_2, write_map_2]: + resolved_sets.append({ + self.track_view(node, state, sdfg).data if self.is_view(node, sdfg) else node.data + for node in unresolved_set.values() + }) + # If the resolved and unresolved names do not have the same length, + # then different views point to the same location, which we forbid. + if len(unresolved_set) != len(resolved_sets[-1]): return False - # Check map ranges - perm = self.find_permutation(first_map_entry.map, second_map_entry.map) - if perm is None: + real_read_map_1, real_write_map_1, real_read_map_2, real_write_map_2 = resolved_sets + + # We do not allow that the first and second map each write to the same data. + # This essentially ensures that an intermediate can not be used as output of + # the second map at the same time. It is actually stronger as it does not + # take their role into account. + if not real_write_map_1.isdisjoint(real_write_map_2): + return True + + # These are the names (unresolved) and the access nodes of the data that is used + # to transmit information between the maps. The partition function ensures that + # these nodes are directly connected to the two maps. + exchange_names: Set[str] = set(write_map_1.keys()).intersection(read_map_2.keys()) + exchange_nodes: Set[nodes.AccessNode] = set(write_map_1.values()).intersection(read_map_2.values()) + + # If the numbers are different then the same data is accessed through different + # AccessNodes. We could analyse this, but we will consider this as a data race. + if len(exchange_names) != len(exchange_nodes): + return True + assert all(exchange_node.data in exchange_names for exchange_node in exchange_nodes) + + # For simplicity we assume that the nodes used for exchange are not views. + if any(self.is_view(exchange_node, sdfg) for exchange_node in exchange_nodes): + return True + + # These are the names of the nodes that are used as input of the first map and + # as output of the second map. We have to ensure that there is no data + # dependency between these nodes. + # NOTE: This set is not required to be empty. It might look as if this would + # create a data race, but it is safe. The reason is that all data has + # to pass through the intermediate we create, which will separate the reads + # from the writes. + fused_inout_data_names: Set[str] = set(read_map_1.keys()).intersection(write_map_2.keys()) + + # If a data container is used as input and output then it can not be a view (for simplicity) + if any(self.is_view(read_map_1[name], sdfg) for name in fused_inout_data_names): + return True + + # A data container can not be used as output (of the second as well as the + # combined map) and as intermediate. If we allowed this, the map would + # have two output nodes: the original one, and a newly created node + # that exists because the intermediate is shared. + # TODO(phimuell): Handle this case. + if not fused_inout_data_names.isdisjoint(exchange_names): + return True + + # While it is forbidden that a data container used as intermediate is also + # used as output of the second map, it is allowed that the data container + # is used as intermediate and as input of the first map. The partition only + # checks that the data dependencies are met, i.e. what is read by the second + # map is also computed (written to the intermediate); it does not take into + # account the first map's reads of the data container. + # To make an example: The partition function will make sure that if the + # second map reads index `i` from the intermediate, then the first map writes + # to that index. But it will not care if the first map reads (through its + # MapEntry) index `i + 1`. In order to be valid we must ensure that the first + # map's reads and writes to the intermediate are pointwise. + # Note that we only have to make this check if it is also an intermediate node. + # Because if it is not read by the second map it is not a problem as the node + # will end up as a pure output node anyway. + read_write_map_1 = set(read_map_1.keys()).intersection(write_map_1.keys()) + datas_to_inspect = read_write_map_1.intersection(exchange_names) + for data_to_inspect in datas_to_inspect: + + # Now get all subsets of the data container that the first map reads + # from or writes to and check if they are pointwise. + all_subsets: List[subsets.Subset] = [] + all_subsets.extend( + self.find_subsets( + node=read_map_1[data_to_inspect], + scope_node=first_map_entry, + state=state, + sdfg=sdfg, + param_repl=None, + )) + all_subsets.extend( + self.find_subsets( + node=write_map_1[data_to_inspect], + scope_node=first_map_exit, + state=state, + sdfg=sdfg, + param_repl=None, + )) + if not self.test_if_subsets_are_point_wise(all_subsets): + return True + + # If there is no intersection between the input and output data, then we + # have nothing to check. + if len(fused_inout_data_names) == 0: return False - # Check if any intermediate transient is also going to another location - second_inodes = set(e.src for e in graph.in_edges(second_map_entry) if isinstance(e.src, nodes.AccessNode)) - transients_to_remove = intermediate_nodes & second_inodes - # if any(e.dst != second_map_entry for n in transients_to_remove - # for e in graph.out_edges(n)): - if any(graph.out_degree(n) > 1 for n in transients_to_remove): - return False + # Now we inspect if there is a read write dependency between data that is + # used as input and output of the fused map. There is no problem if they + # are pointwise, i.e. in each iteration the same locations are accessed. + # Essentially they all boil down to `a += 1`. + for inout_data_name in fused_inout_data_names: + all_subsets: List[subsets.Subset] = [] + # The subsets that define reading are given by the first map's entry node + all_subsets.extend( + self.find_subsets( + node=read_map_1[inout_data_name], + scope_node=first_map_entry, + state=state, + sdfg=sdfg, + param_repl=None, + )) + # While the subsets defining writing are given by the second map's exit + # node, there we also have to apply renaming. + all_subsets.extend( + self.find_subsets( + node=write_map_2[inout_data_name], + scope_node=second_map_exit, + state=state, + sdfg=sdfg, + param_repl=param_repl, + )) + # Now we can test if these subsets are point wise + if not self.test_if_subsets_are_point_wise(all_subsets): + return True + + # No read write dependency was found. + return False - # Create a dict that maps parameters of the first map to those of the - # second map.
- params_dict = {} - for _index, _param in enumerate(second_map_entry.map.params): - params_dict[_param] = first_map_entry.map.params[perm[_index]] - out_memlets = [e.data for e in graph.in_edges(first_map_exit)] + def test_if_subsets_are_point_wise(self, subsets_to_check: List[subsets.Subset]) -> bool: + """Point wise means that they are all the same. - # Check that input set of second map is provided by the output set - # of the first map, or other unrelated maps - for second_edge in graph.out_edges(second_map_entry): - # NOTE: We ignore edges that do not carry data (e.g., connecting a tasklet with no inputs to the MapEntry) - if second_edge.data.is_empty(): - continue - # Memlets that do not come from one of the intermediate arrays - if second_edge.data.data not in intermediate_data: - # however, if intermediate_data eventually leads to - # second_memlet.data, need to fail. - for _n in intermediate_nodes: - source_node = _n - destination_node = graph.memlet_path(second_edge)[0].src - # NOTE: Assumes graph has networkx version - if destination_node in nx.descendants(graph._nx, source_node): - return False - continue + If a series of subsets is point wise it means that all Memlets access + the same data. This is an important property because the whole map fusion + is built upon this. + If the subsets originate from different maps, then they must have been + renamed. - provided = False + :param subsets_to_check: The list of subsets that should be checked. + """ + assert len(subsets_to_check) > 1 + + # We will check everything against the master subset. + master_subset = subsets_to_check[0] + for ssidx in range(1, len(subsets_to_check)): + subset = subsets_to_check[ssidx] + if isinstance(subset, subsets.Indices): + subset = subsets.Range.from_indices(subset) + # Do we also need the reverse? See below why. + if any(r != (0, 0, 1) for r in subset.offset_new(master_subset, negative=True)): + return False + else: + # The original code used `Range.offset` here, but that one had trouble + # for `r1 = 'j, 0:10'` and `r2 = 'j, 0'`. The solution would be to test + # symmetrically, i.e. `r1 - r2` and `r2 - r1`. However, if we would + # have `r2_1 = 'j, 0:10'` it would consider it as failing, which is not + # what we want. Thus we will use symmetric cover. + if not master_subset.covers(subset): + return False + if not subset.covers(master_subset): + return False + + # All subsets are equal to the master subset, thus they are equal to each other. + # This means that the data accesses described by these subsets are + # point wise. + return True - # Compute second subset with respect to first subset's symbols - sbs_permuted = dcpy(second_edge.data.subset) - if sbs_permuted: - # Create intermediate dicts to avoid conflicts, such as {i:j, j:i} - symbolic.safe_replace(params_dict, lambda m: sbs_permuted.replace(m)) - for first_memlet in out_memlets: - if first_memlet.data != second_edge.data.data: + def is_shared_data( + self, + data: nodes.AccessNode, + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> bool: + """Tests if `data` is shared data, i.e. it can not be removed from the SDFG. + + Depending on the situation, the function will not perform a scan of the whole SDFG: + 1) If `assume_always_shared` was set to `True`, the function will return `True` unconditionally. + 2) If `data` is non-transient then the function will return `True`, as non-transient data + must be reconstructed always.
+ 3) If the AccessNode `data` has more than one outgoing edge or more than one incoming edge + it is classified as shared. + 4) If `FindSingleUseData` is in the pipeline it will be used and no scan will be performed. + 5) The function will perform a scan. + + :param data: The transient that should be checked. + :param state: The state in which the fusion is performed. + :param sdfg: The SDFG in which we want to perform the fusing. + """ + # `assume_always_shared` takes precedence. + if self.assume_always_shared: + return True + + # If `data` is non-transient then return `True` as the intermediate can not be removed. + if not data.desc(sdfg).transient: + return True + + # This means the data is consumed by multiple Maps, through the same AccessNode, in this state. + # Note that currently multiple incoming edges are not handled, but in the spirit of this function + # we consider such AccessNodes as shared, because we can not remove the intermediate. + if state.out_degree(data) > 1: + return True + if state.in_degree(data) > 1: + return True + + # NOTE: Actually, if this transformation is run through the `FullMapFusion` pass, it should + # read the results of `FindSingleUseData`, which was computed because it is a dependent + # pass, through the `self._pipeline_results` member, which is set by the `SingleStateTransformation`. + # However, this member is only set when `apply()` is called, but not during + # `can_be_applied()`, see [issue#1911](https://github.com/spcl/dace/issues/1911). + # Because the whole goal of this separation of scanning and fusion was to make the + # transformation stateless, the member `_single_use_data` was introduced. If it is set + # then we use it, otherwise we use the scanner. + # This value is set for example by the `FullMapFusion` pass. + # TODO(phimuell): Change this once the issue is resolved. + if self._single_use_data is not None: + assert sdfg in self._single_use_data, f"`_single_use_data` was set, but does not contain information about the SDFG '{sdfg.name}'." + single_use_data: Set[str] = self._single_use_data[sdfg] + return data.data not in single_use_data + + # We have to perform the full scan of the SDFG. + return self._scan_sdfg_if_data_is_shared(data=data, state=state, sdfg=sdfg) + + + def _scan_sdfg_if_data_is_shared( + self, + data: nodes.AccessNode, + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> bool: + """Scans `sdfg` to determine if `data` is shared. + + Essentially, this function determines if the intermediate AccessNode `data` + can be removed or if it has to be restored as output of the Map. + A data descriptor is classified as shared if any of the following is true: + - `data` is non-transient data. + - `data` has more than one incoming or outgoing edge. + - There are other AccessNodes beside `data` that refer to the same data. + - The data is accessed on an interstate edge. + + This function should not be called directly. Instead it is called indirectly + by `is_shared_data()` if there is no shortcut. + + :param data: The AccessNode that should be checked if it is shared. + :param state: The state in which `data` is located. + :param sdfg: The SDFG for which the set of shared data should be computed. + """ + if not data.desc(sdfg).transient: + return True + + # See description in `is_shared_data()` for more. + if state.out_degree(data) > 1: + return True + if state.in_degree(data) > 1: + return True + + data_name: str = data.data + for other_state in sdfg.states(): + for dnode in other_state.data_nodes(): + if dnode is data: + # We have found the `data` AccessNode, which we must ignore.
continue + if dnode.data == data_name: + # We found a different AccessNode that refers to the same data + # as `data`. Thus `data` is shared. + return True + + # Test if the data is referenced in the interstate edges. + for edge in sdfg.edges(): + if data_name in edge.data.free_symbols: + # The data is used in the interstate edges, so it is shared. + return True + + # Test if the data is referenced inside a control flow region, such as a conditional + # block or loop condition. + for cfr in sdfg.all_control_flow_regions(): + if data_name in cfr.used_symbols(all_symbols=True, with_contents=False): + return True + + # The `data` is not used anywhere else, thus `data` is not shared. + return False - # If there is a covered subset, it is provided - if first_memlet.subset.covers(sbs_permuted): - provided = True - break - # If none of the output memlets of the first map provide the info, - # fail. - if provided is False: - return False + def find_parameter_remapping(self, first_map: nodes.Map, second_map: nodes.Map) -> Optional[Dict[str, str]]: + """Computes the parameter remapping for the parameters of the _second_ map. + + The returned `dict` maps the parameters of the second map (keys) to parameter + names of the first map (values). Because of how the replace function works + the `dict` describes how to replace the parameters of the second map + with parameters of the first map. + Parameters that already have the correct name and compatible range are not + included in the return value, thus the keys and values are always different. + If no renaming at all is _needed_, i.e. all parameters have the same name and range, + then the function returns an empty `dict`. + If no remapping exists, then the function will return `None`. - # Checking for stencil pattern and common input/output data - # (after fusing the maps) - first_map_inputnodes = { - e.src: e.src.data - for e in graph.in_edges(first_map_entry) if isinstance(e.src, nodes.AccessNode) + :param first_map: The first map (its parameters are used as replacements). + :param second_map: The second map (its parameters will be replaced). + """ + + # The parameter names + first_params: List[str] = first_map.params + second_params: List[str] = second_map.params + + if len(first_params) != len(second_params): + return None + + # The ranges; however, we apply some post-processing to them.
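+ # For example, `0:N` and `0:(N + 1) - 1` should only compare as equal + # after this simplification step.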
+ simp = lambda e: symbolic.simplify_ext(symbolic.simplify(e)) + first_rngs: Dict[str, Tuple[Any, Any, Any]] = { + param: tuple(simp(r) for r in rng) + for param, rng in zip(first_params, first_map.range) } - input_views = set() - viewed_inputnodes = dict() - for n in first_map_inputnodes.keys(): - if isinstance(n.desc(sdfg), data.View): - input_views.add(n) - for v in input_views: - del first_map_inputnodes[v] - e = sdutil.get_view_edge(graph, v) - if e: - while not isinstance(e.src, nodes.AccessNode): - e = graph.memlet_path(e)[0] - first_map_inputnodes[e.src] = e.src.data - viewed_inputnodes[e.src.data] = v - second_map_outputnodes = { - e.dst: e.dst.data - for e in graph.out_edges(second_map_exit) if isinstance(e.dst, nodes.AccessNode) + second_rngs: Dict[str, Tuple[Any, Any, Any]] = { + param: tuple(simp(r) for r in rng) + for param, rng in zip(second_params, second_map.range) } - output_views = set() - viewed_outputnodes = dict() - for n in second_map_outputnodes: - if isinstance(n.desc(sdfg), data.View): - output_views.add(n) - for v in output_views: - del second_map_outputnodes[v] - e = sdutil.get_view_edge(graph, v) - if e: - while not isinstance(e.dst, nodes.AccessNode): - e = graph.memlet_path(e)[-1] - second_map_outputnodes[e.dst] = e.dst.data - viewed_outputnodes[e.dst.data] = v - common_data = set(first_map_inputnodes.values()).intersection(set(second_map_outputnodes.values())) - if common_data: - input_data = [viewed_inputnodes[d].data if d in viewed_inputnodes.keys() else d for d in common_data] - input_accesses = [ - graph.memlet_path(e)[-1].data.src_subset for e in graph.out_edges(first_map_entry) - if e.data.data in input_data - ] - if len(input_accesses) > 1: - for i, a in enumerate(input_accesses[:-1]): - for b in input_accesses[i + 1:]: - if isinstance(a, subsets.Indices): - c = subsets.Range.from_indices(a) - c.offset(b, negative=True) - else: - c = a.offset_new(b, negative=True) - for r in c: - if r != (0, 0, 1): - return False - - output_data = [viewed_outputnodes[d].data if d in viewed_outputnodes.keys() else d for d in common_data] - output_accesses = [ - graph.memlet_path(e)[0].data.dst_subset for e in graph.in_edges(second_map_exit) - if e.data.data in output_data - ] - - # Compute output accesses with respect to first map's symbols - oacc_permuted = [dcpy(a) for a in output_accesses] - for a in oacc_permuted: - # Create intermediate dicts to avoid conflicts, such as {i:j, j:i} - symbolic.safe_replace(params_dict, lambda m: a.replace(m)) - - a = input_accesses[0] - for b in oacc_permuted: - if isinstance(a, subsets.Indices): - c = subsets.Range.from_indices(a) - c.offset(b, negative=True) - else: - c = a.offset_new(b, negative=True) - for r in c: - if r != (0, 0, 1): - return False - # Success - return True + # Parameters of the second map that have not yet been matched to a parameter + # of the first map and vice versa. + unmapped_second_params: Set[str] = set(second_params) + unused_first_params: Set[str] = set(first_params) + + # This is the result (`second_param -> first_param`), note that if no renaming + # is needed then the parameter is not present in the mapping. + final_mapping: Dict[str, str] = {} + + # First we identify the parameters that already have the correct name. + for param in set(first_params).intersection(second_params): + first_rng = first_rngs[param] + second_rng = second_rngs[param] + + if first_rng == second_rng: + # They have the same name and the same range, this is already a match. 
+ # Because the names are already the same, we do not have to enter them + # in the `final_mapping`. + unmapped_second_params.discard(param) + unused_first_params.discard(param) + + # Check if no remapping is needed. + if len(unmapped_second_params) == 0: + return {} + + # Now we go through all the parameters that we have not mapped yet. + # All of them will result in a remapping. + for unmapped_second_param in unmapped_second_params: + second_rng = second_rngs[unmapped_second_param] + assert unmapped_second_param not in final_mapping + + # Now look through the not yet used parameters of the first map for one to use. + for candidate_param in unused_first_params: + candidate_rng = first_rngs[candidate_param] + if candidate_rng == second_rng: + final_mapping[unmapped_second_param] = candidate_param + unused_first_params.discard(candidate_param) + break + else: + # We did not find a candidate, so the remapping does not exist. + return None + + assert len(unused_first_params) == 0 + assert len(final_mapping) == len(unmapped_second_params) + return final_mapping + + + def rename_map_parameters( + self, + first_map: nodes.Map, + second_map: nodes.Map, + second_map_entry: nodes.MapEntry, + state: SDFGState, + ) -> None: + """Replaces the map parameters of the second map with names from the first. + + The replacement is done in a safe way, thus `{'i': 'j', 'j': 'i'}` is + handled correctly. The function assumes that a proper replacement exists. + The replacement is computed by calling `self.find_parameter_remapping()`. + + :param first_map: The first map (these are the final parameters). + :param second_map: The second map, this map will be replaced. + :param second_map_entry: The entry node of the second map. + :param state: The SDFGState on which we operate. + """ + # Compute the replacement dict. + repl_dict: Dict[str, str] = self.find_parameter_remapping(first_map=first_map, second_map=second_map) + + if repl_dict is None: + raise RuntimeError("The replacement does not exist") + if len(repl_dict) == 0: + return + + second_map_scope = state.scope_subgraph(entry_node=second_map_entry) + # Why is this in symbolic and not in replace? + symbolic.safe_replace( + mapping=repl_dict, + replace_callback=second_map_scope.replace_dict, + ) + + # For some odd reason the replace function does not modify the range and + # parameters of the map, so we will do it the hard way. + second_map.params = copy.deepcopy(first_map.params) + second_map.range = copy.deepcopy(first_map.range) + - def apply(self, graph: SDFGState, sdfg: SDFG): + def is_node_reachable_from( + self, + graph: Union[dace.SDFG, dace.SDFGState], + begin: nodes.Node, + end: nodes.Node, + ) -> bool: + """Test if the node `end` can be reached from `begin`. + + Essentially the function starts a DFS at `begin`. If an edge is found that leads + to `end` the function returns `True`. If the node is never found `False` is + returned. + + :param graph: The graph to operate on. + :param begin: The start of the DFS. + :param end: The node that should be located. """ - This method applies the mapfusion transformation. - Other than the removal of the second map entry node (SME), and the first - map exit (FME) node, it has the following side effects: - 1. Any transient adjacent to both FME and SME with degree = 2 will be removed.
- The tasklets that use/produce it shall be connected directly with a - scalar/new transient (if the dataflow is more than a single scalar) + def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: + return (edge.dst for edge in graph.out_edges(node)) + + to_visit: List[nodes.Node] = [begin] + seen: Set[nodes.Node] = set() - 2. If this transient is adjacent to FME and SME and has other - uses, it will be adjacent to the new map exit post fusion. - Tasklet-> Tasklet edges will ALSO be added as mentioned above. + while len(to_visit) > 0: + node: nodes.Node = to_visit.pop() + if node == end: + return True + elif node not in seen: + to_visit.extend(next_nodes(node)) + seen.add(node) + + # We never found `end` + return False - 3. If an access node is adjacent to FME but not SME, it will be - adjacent to new map exit post fusion. - 4. If an access node is adjacent to SME but not FME, it will be - adjacent to the new map entry node post fusion. + def _is_data_accessed_downstream( + self, + data: str, + graph: dace.SDFGState, + begin: nodes.Node, + ) -> bool: + """Tests if there is an AccessNode for `data` downstream of `begin`. + Essentially, this function starts a DFS at `begin` and checks every + AccessNode that is reachable from it. If it finds such a node it will + check if it refers to `data` and if so, it will return `True`. + If no such node is found it will return `False`. + Note that the node `begin` will be ignored. + + :param data: The name of the data to look for. + :param graph: The graph to explore. + :param begin: The node to start exploration; The node itself is ignored. """ - first_exit = self.first_map_exit - first_entry = graph.entry_node(first_exit) - second_entry = self.second_map_entry - second_exit = graph.exit_node(second_entry) - - intermediate_nodes = set() - for _, _, dst, _, _ in graph.out_edges(first_exit): - intermediate_nodes.add(dst) - assert isinstance(dst, nodes.AccessNode) - - # Check if an access node refers to non transient memory, or transient - # is used at another location (cannot erase) - do_not_erase = set() - for node in intermediate_nodes: - if sdfg.arrays[node.data].transient is False: - do_not_erase.add(node) - else: - for edge in graph.in_edges(node): - if edge.src != first_exit: - do_not_erase.add(node) - break - else: - for edge in graph.out_edges(node): - if edge.dst != second_entry: - do_not_erase.add(node) - break - - # Find permutation between first and second scopes - perm = self.find_permutation(first_entry.map, second_entry.map) - params_dict = {} - for index, param in enumerate(first_entry.map.params): - params_dict[param] = second_entry.map.params[perm[index]] - - # Replaces (in memlets and tasklet) the second scope map - # indices with the permuted first map indices. - # This works in two passes to avoid problems when e.g., exchanging two - # parameters (instead of replacing (j,i) and (i,j) to (j,j) and then - # i,i). 
- second_scope = graph.scope_subgraph(second_entry) - for firstp, secondp in params_dict.items(): - if firstp != secondp: - replace(second_scope, secondp, '__' + secondp + '_fused') - for firstp, secondp in params_dict.items(): - if firstp != secondp: - replace(second_scope, '__' + secondp + '_fused', firstp) - - # Isolate First exit node - ############################ - edges_to_remove = set() - nodes_to_remove = set() - for edge in graph.in_edges(first_exit): - tree = graph.memlet_tree(edge) - access_node = tree.root().edge.dst - if access_node not in do_not_erase: - out_edges = [e for e in graph.out_edges(access_node) if e.dst == second_entry] - # In this transformation, there can only be one edge to the - # second map - assert len(out_edges) == 1 - - # Get source connector to the second map - connector = out_edges[0].dst_conn[3:] - - new_dsts = [] - # Look at the second map entry out-edges to get the new - # destinations - for e in graph.out_edges(second_entry): - if e.src_conn and e.src_conn[4:] == connector: - new_dsts.append(e) - if not new_dsts: # Access node is not used in the second map - nodes_to_remove.add(access_node) - continue + def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: + return (edge.dst for edge in graph.out_edges(node)) + + # Dataflow graph is acyclic, so we do not need to keep a list of + # what we have visited. + to_visit: List[nodes.Node] = list(next_nodes(begin)) + while len(to_visit) > 0: + node = to_visit.pop() + if isinstance(node, nodes.AccessNode) and node.data == data: + return True + to_visit.extend(next_nodes(node)) - # Add a transient scalar/array - self.fuse_nodes(sdfg, graph, edge, new_dsts[0].dst, new_dsts[0].dst_conn, new_dsts[1:]) - - edges_to_remove.add(edge) - - # Remove transient node between the two maps - nodes_to_remove.add(access_node) - else: # The case where intermediate array node cannot be removed - # Node will become an output of the second map exit - out_e = tree.parent.edge - conn = second_exit.next_connector() - graph.add_edge( - second_exit, - 'OUT_' + conn, - out_e.dst, - out_e.dst_conn, - dcpy(out_e.data), - ) - second_exit.add_out_connector('OUT_' + conn) - - graph.add_edge(edge.src, edge.src_conn, second_exit, 'IN_' + conn, dcpy(edge.data)) - second_exit.add_in_connector('IN_' + conn) - - edges_to_remove.add(out_e) - edges_to_remove.add(edge) - - # If the second map needs this node, link the connector - # that generated this to the place where it is needed, with a - # temp transient/scalar for memlet to be generated - for out_e in graph.out_edges(second_entry): - second_memlet_path = graph.memlet_path(out_e) - source_node = second_memlet_path[0].src - if source_node == access_node: - self.fuse_nodes(sdfg, graph, edge, out_e.dst, out_e.dst_conn) - - ### - # First scope exit is isolated and can now be safely removed - for e in edges_to_remove: - graph.remove_edge(e) - graph.remove_nodes_from(nodes_to_remove) - graph.remove_node(first_exit) - - # Isolate second_entry node - ########################### - for edge in graph.in_edges(second_entry): - tree = graph.memlet_tree(edge) - access_node = tree.root().edge.src - if access_node in intermediate_nodes: - # Already handled above, can be safely removed - graph.remove_edge(edge) - continue + return False - # This is an external input to the second map which will now go - # through the first map. 
- conn = first_entry.next_connector() - graph.add_edge(edge.src, edge.src_conn, first_entry, 'IN_' + conn, dcpy(edge.data)) - first_entry.add_in_connector('IN_' + conn) - graph.remove_edge(edge) - for out_enode in tree.children: - out_e = out_enode.edge - graph.add_edge( - first_entry, - 'OUT_' + conn, - out_e.dst, - out_e.dst_conn, - dcpy(out_e.data), - ) - graph.remove_edge(out_e) - first_entry.add_out_connector('OUT_' + conn) - - # NOTE: Check the second MapEntry for output edges with empty memlets - for edge in graph.out_edges(second_entry): - if edge.data.is_empty(): - graph.remove_edge(edge) - graph.add_edge(first_entry, edge.src_conn, edge.dst, edge.dst_conn, edge.data) - - ### - # Second node is isolated and can now be safely removed - graph.remove_node(second_entry) - - # Fix scope exit to point to the right map - second_exit.map = first_entry.map - - def fuse_nodes(self, sdfg: SDFG, graph: SDFGState, edge, new_dst, new_dst_conn, other_edges=None): - """ Fuses two nodes via memlets and possibly transient arrays. """ - other_edges = other_edges or [] - memlet_path = graph.memlet_path(edge) - access_node = memlet_path[-1].dst - - local_name = "__s%d_n%d%s_n%d%s" % ( - self.state_id, - graph.node_id(edge.src), - edge.src_conn, - graph.node_id(edge.dst), - edge.dst_conn, - ) - # Add intermediate memory between subgraphs. - # If a scalar, uses direct connection. If an array, adds a transient node. - # NOTE: If any of the src/dst nodes is a nested SDFG, treat it as an array. - is_scalar = edge.data.subset.num_elements() == 1 - accesses = ( - [graph.memlet_path(e1)[0].src for e0 in graph.in_edges(access_node) for e1 in graph.memlet_tree(e0)] + - [graph.memlet_path(e1)[-1].dst for e0 in graph.out_edges(access_node) for e1 in graph.memlet_tree(e0)]) - if any(isinstance(a, nodes.NestedSDFG) for a in accesses): - is_scalar = False - if is_scalar: - local_name, _ = sdfg.add_scalar( - local_name, - dtype=access_node.desc(graph).dtype, - transient=True, - storage=dtypes.StorageType.Register, - find_new_name=True, - ) - edge.data.data = local_name - edge.data.subset = "0" - - # If source of edge leads to multiple destinations, redirect all through an access node. - out_edges = list(graph.out_edges_by_connector(edge.src, edge.src_conn)) - if len(out_edges) > 1: - local_node = graph.add_access(local_name) - src_connector = None - - # Add edge that leads to transient node - graph.add_edge(edge.src, edge.src_conn, local_node, None, dcpy(edge.data)) - - for other_edge in out_edges: - if other_edge is not edge: - graph.remove_edge(other_edge) - mem = Memlet(data=local_name, other_subset=other_edge.data.dst_subset) - graph.add_edge(local_node, src_connector, other_edge.dst, other_edge.dst_conn, mem) - else: - local_node = edge.src - src_connector = edge.src_conn - - # update edge data in case source or destination is a scalar access node - test_data = [node.data for node in (edge.src, edge.dst) if isinstance(node, nodes.AccessNode)] - for new_data in test_data: - if isinstance(sdfg.arrays[new_data], data.Scalar): - edge.data.data = new_data - - # If destination of edge leads to multiple destinations, redirect all through an access node. - if other_edges: - # NOTE: If a new local node was already created, reuse it. 
- if local_node == edge.src: - local_node_out = graph.add_access(local_name) - connector_out = None - else: - local_node_out = local_node - connector_out = src_connector - graph.add_edge(local_node, src_connector, local_node_out, connector_out, - Memlet.from_array(local_name, sdfg.arrays[local_name])) - graph.add_edge(local_node_out, connector_out, new_dst, new_dst_conn, dcpy(edge.data)) - for e in other_edges: - graph.add_edge(local_node_out, connector_out, e.dst, e.dst_conn, dcpy(edge.data)) - else: - # Add edge that leads to the second node - graph.add_edge(local_node, src_connector, new_dst, new_dst_conn, dcpy(edge.data)) + def get_access_set( + self, + scope_node: Union[nodes.MapEntry, nodes.MapExit], + state: SDFGState, + ) -> Set[nodes.AccessNode]: + """Computes the access set of a "scope node". + + If `scope_node` is a `MapEntry` it will operate on the set of incoming edges + and if it is a `MapExit` on the set of outgoing edges. The function will + then determine all access nodes that have a connection through these edges + to the scope node (edges that do not lead to access nodes are ignored). + The function returns a set that contains all access nodes that were found. + It is important to note that this set will also contain views. + + :param scope_node: The scope node that should be evaluated. + :param state: The state in which we operate. + """ + if isinstance(scope_node, nodes.MapEntry): + get_edges = lambda node: state.in_edges(node) + other_node = lambda e: e.src else: - local_name, _ = sdfg.add_transient(local_name, - symbolic.overapproximate(edge.data.subset.size()), - dtype=access_node.desc(graph).dtype, - find_new_name=True) - old_edge = dcpy(edge) - local_node = graph.add_access(local_name) - src_connector = None - edge.data.data = local_name - edge.data.subset = ",".join(["0:" + str(s) for s in edge.data.subset.size()]) - # Add edge that leads to transient node - graph.add_edge( - edge.src, - edge.src_conn, - local_node, - None, - dcpy(edge.data), - ) + get_edges = lambda node: state.out_edges(node) + other_node = lambda e: e.dst + access_set: Set[nodes.AccessNode] = { + node + for node in map(other_node, get_edges(scope_node)) if isinstance(node, nodes.AccessNode) + } - # Add edge that leads to the second node - graph.add_edge(local_node, src_connector, new_dst, new_dst_conn, dcpy(edge.data)) + return access_set + + + def find_subsets( + self, + node: nodes.AccessNode, + scope_node: Union[nodes.MapExit, nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, + param_repl: Optional[Dict[str, str]], + ) -> List[subsets.Subset]: + """Finds all subsets that access `node` within `scope_node`. + + The function will not start a search for all consumers/producers. + Instead it will locate the edges that are immediately inside the + map scope. + + :param node: The access node that should be examined. + :param scope_node: We are only interested in data that flows through this node. + :param state: The state in which we operate. + :param sdfg: The SDFG object. + :param param_repl: `dict` that describes the parameter renaming that should be + performed. Can be `None` to skip the processing. + """ + # Is the node used for reading or for writing? + # This influences how we have to proceed.
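+ # For a MapEntry we look at the outer (incoming) edges that originate at + # `node` and follow the matching `OUT_*` connectors to the edges inside the + # scope; for a MapExit the roles are mirrored via the `IN_*` connectors.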
+ if isinstance(scope_node, nodes.MapEntry): + outer_edges_to_inspect = [e for e in state.in_edges(scope_node) if e.src == node] + get_subset = lambda e: e.data.src_subset + get_inner_edges = lambda e: state.out_edges_by_connector(scope_node, "OUT_" + e.dst_conn[3:]) + else: + outer_edges_to_inspect = [e for e in state.out_edges(scope_node) if e.dst == node] + get_subset = lambda e: e.data.dst_subset + get_inner_edges = lambda e: state.in_edges_by_connector(scope_node, "IN_" + e.src_conn[4:]) + + found_subsets: List[subsets.Subset] = [] + for edge in outer_edges_to_inspect: + found_subsets.extend(get_subset(e) for e in get_inner_edges(edge)) + assert len(found_subsets) > 0, "Could not find any subsets." + assert not any(subset is None for subset in found_subsets) + + found_subsets = copy.deepcopy(found_subsets) + if param_repl: + for subset in found_subsets: + # Replace happens in place + symbolic.safe_replace(param_repl, subset.replace) + + return found_subsets + + + def is_view( + self, + node: Union[nodes.AccessNode, data.Data], + sdfg: SDFG, + ) -> bool: + """Tests if `node` points to a view or not.""" + node_desc: data.Data = node if isinstance(node, data.Data) else node.desc(sdfg) + return isinstance(node_desc, data.View) + + + def track_view( + self, + view: nodes.AccessNode, + state: SDFGState, + sdfg: SDFG, + ) -> nodes.AccessNode: + """Find the original data of a View. + + Given the View `view`, the function will trace the view back to the original + access node. For convenience, if `view` is not a `View` the argument will be + returned. + + :param view: The view that should be traced. + :param state: The state in which we operate. + :param sdfg: The SDFG on which we operate. + """ - for e in other_edges: - graph.add_edge(local_node, src_connector, e.dst, e.dst_conn, dcpy(edge.data)) + # Test if it is a view at all, if not return the passed node as source. + if not self.is_view(view, sdfg): + return view - # Modify data and memlets on all surrounding edges to match array - for neighbor in graph.all_edges(local_node): - for e in graph.memlet_tree(neighbor): - if e.data.data == local_name: - continue - e.data.data = local_name - e.data.subset.offset(old_edge.data.subset, negative=True) + # This is the node that defines the view. + defining_node = dace.sdfg.utils.get_last_view_node(state, view) + assert isinstance(defining_node, nodes.AccessNode) + assert not self.is_view(defining_node, sdfg) + return defining_node diff --git a/dace/transformation/passes/__init__.py b/dace/transformation/passes/__init__.py index e8f19d181e..98bb4b7602 100644 --- a/dace/transformation/passes/__init__.py +++ b/dace/transformation/passes/__init__.py @@ -4,6 +4,7 @@ from .constant_propagation import ConstantPropagation from .dead_dataflow_elimination import DeadDataflowElimination from .dead_state_elimination import DeadStateElimination +from .full_map_fusion import FullMapFusion from .fusion_inline import FuseStates, InlineSDFGs from .optional_arrays import OptionalArrayInference from .pattern_matching import PatternMatchAndApply, PatternMatchAndApplyRepeated, PatternApplyOnceEverywhere diff --git a/dace/transformation/passes/full_map_fusion.py b/dace/transformation/passes/full_map_fusion.py new file mode 100644 index 0000000000..8b1357611e --- /dev/null +++ b/dace/transformation/passes/full_map_fusion.py @@ -0,0 +1,144 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. 
+from typing import Any, Dict, Optional, Set + +import warnings + +from dace import SDFG, SDFGState, properties, transformation +from dace.transformation import pass_pipeline as ppl +from dace.transformation.dataflow import MapFusion +from dace.transformation.passes import analysis as ap, pattern_matching as pmp + + +@properties.make_properties +@transformation.explicit_cf_compatible +class FullMapFusion(ppl.Pass): + """ + Pass that combines `MapFusion` and `FindSingleUseData` into one. + + Essentially, this pass runs `FindSingleUseData` before `MapFusion`; this + speeds up the fusion, as the SDFG has to be scanned only once. + The pass accepts the same options as `MapFusion`, for a detailed description + see there. + + :param only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + :param only_toplevel_maps: Only consider Maps that are at the top. + :param strict_dataflow: Which dataflow mode should be used, see `MapFusion`. + :param assume_always_shared: Assume that all intermediates are shared. + :param validate: Validate the SDFG after the pass has finished. + :param validate_all: Validate the SDFG after every transformation. + + :todo: Implement a faster matcher as the pattern is constant. + """ + + CATEGORY: str = 'Simplification' + + # Settings + only_toplevel_maps = properties.Property( + dtype=bool, + default=False, + desc='Only perform fusing if the Maps are in the top level.', + ) + only_inner_maps = properties.Property( + dtype=bool, + default=False, + desc='Only perform fusing if the Maps are inner Maps, i.e. are not in the top-level scope.', + ) + strict_dataflow = properties.Property( + dtype=bool, + default=True, + desc='If `True` then the transformation will ensure a stricter data flow.', + ) + assume_always_shared = properties.Property( + dtype=bool, + default=False, + desc='If `True` then all intermediates will be classified as shared.', + ) + + validate = properties.Property( + dtype=bool, + default=True, + desc='If True, validates the SDFG after all transformations have been applied.', + ) + validate_all = properties.Property( + dtype=bool, + default=False, + desc='If True, validates the SDFG after each transformation is applied.' + ) + + + def __init__( + self, + only_inner_maps: Optional[bool] = None, + only_toplevel_maps: Optional[bool] = None, + strict_dataflow: Optional[bool] = None, + assume_always_shared: Optional[bool] = None, + validate: Optional[bool] = None, + validate_all: Optional[bool] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if only_toplevel_maps is not None: + self.only_toplevel_maps = only_toplevel_maps + if only_inner_maps is not None: + self.only_inner_maps = only_inner_maps + if strict_dataflow is not None: + self.strict_dataflow = strict_dataflow + if assume_always_shared is not None: + self.assume_always_shared = assume_always_shared + if validate is not None: + self.validate = validate + if validate_all is not None: + self.validate_all = validate_all + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.Scopes | ppl.Modifies.AccessNodes | ppl.Modifies.Memlets + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & (ppl.Modifies.Scopes | ppl.Modifies.AccessNodes | ppl.Modifies.Memlets | ppl.Modifies.States) + + def depends_on(self): + return {ap.FindSingleUseData} + + def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[int]: + """ + Fuses all Maps that can be fused in the SDFG, including its nested SDFGs.
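+ + A minimal usage sketch (assuming `sdfg` is an existing SDFG; `strict_dataflow=True` + is just one possible configuration): + + from dace.transformation import pass_pipeline as ppl + from dace.transformation.passes import FullMapFusion + + pipeline = ppl.Pipeline([FullMapFusion(strict_dataflow=True)]) + pipeline.apply_pass(sdfg, {})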
+ + For driving the fusion the function will construct a `PatternMatchAndApplyRepeated` + object. + + :param sdfg: The SDFG to modify. + :param pipeline_results: The result of previous pipeline steps. The pass expects + at least the result of the `FindSingleUseData`. + :return: The number of Maps that were fused, or `None` if none were fused. + """ + if ap.FindSingleUseData.__name__ not in pipeline_results: + raise ValueError('Expected to find `FindSingleUseData` in `pipeline_results`.') + + fusion = MapFusion( + only_inner_maps=self.only_inner_maps, + only_toplevel_maps=self.only_toplevel_maps, + strict_dataflow=self.strict_dataflow, + assume_always_shared=self.assume_always_shared + ) + + try: + # The short answer why we do this is because `fusion._pipeline_results` is + # only defined during `apply()` and not during `can_be_applied()`. For more + # information see the note in `MapFusion.is_shared_data()` and/or [issue#1911](https://github.com/spcl/dace/issues/1911). + assert fusion._single_use_data is None + fusion._single_use_data = pipeline_results["FindSingleUseData"] + pazz = pmp.PatternMatchAndApplyRepeated( + [fusion], + permissive=False, + validate=False, + validate_all=self.validate_all, + ) + result = pazz.apply_pass(sdfg, pipeline_results) + + finally: + fusion._single_use_data = None + + if self.validate: + sdfg.validate() + + return result diff --git a/tests/buffer_tiling_test.py b/tests/buffer_tiling_test.py index 52477dcc72..03635a7721 100644 --- a/tests/buffer_tiling_test.py +++ b/tests/buffer_tiling_test.py @@ -78,6 +78,7 @@ def _semantic_eq(tile_sizes, program): count = sdfg.apply_transformations(BufferTiling, options={'tile_sizes': tile_sizes}) assert count > 0 + sdfg.validate() sdfg(w3=w3, w5=w5, A=A, B=B2, I=A.shape[0], J=A.shape[1]) assert np.allclose(B1, B2) diff --git a/tests/npbench/polybench/correlation_test.py b/tests/npbench/polybench/correlation_test.py index d743ba528d..a5532cf829 100644 --- a/tests/npbench/polybench/correlation_test.py +++ b/tests/npbench/polybench/correlation_test.py @@ -83,7 +83,8 @@ def run_correlation(device_type: dace.dtypes.DeviceType): # Compute ground truth and validate result corr_ref = ground_truth(M, float_n_ref, data_ref) - assert np.allclose(corr, corr_ref) + diff = corr_ref - corr + assert np.abs(diff).max() <= 10e-10 return sdfg diff --git a/tests/npbench/polybench/heat_3d_test.py b/tests/npbench/polybench/heat_3d_test.py index e058914fd3..75ad902c4b 100644 --- a/tests/npbench/polybench/heat_3d_test.py +++ b/tests/npbench/polybench/heat_3d_test.py @@ -64,10 +64,22 @@ def run_heat_3d(device_type: dace.dtypes.DeviceType): A_ref = np.copy(A) B_ref = np.copy(B) + def count_maps(sdfg: dc.SDFG) -> int: + nb_maps = 0 + for node, _ in sdfg.all_nodes_recursive(): + if isinstance(node, dc.sdfg.nodes.MapEntry): + nb_maps += 1 + return nb_maps + if device_type in {dace.dtypes.DeviceType.CPU, dace.dtypes.DeviceType.GPU}: # Parse the SDFG and apply auto-opt sdfg = heat_3d_kernel.to_sdfg() + initial_maps = count_maps(sdfg) sdfg = auto_optimize(sdfg, device_type) + after_maps = count_maps(sdfg) + assert after_maps < initial_maps, f"Expected fewer maps: initially {initial_maps}, but after optimization {after_maps}" sdfg(TSTEPS, A, B, N=N) elif device_type == dace.dtypes.DeviceType.FPGA: # Parse SDFG and apply FPGA friendly optimization diff --git a/tests/npbench/polybench/jacobi_2d_test.py b/tests/npbench/polybench/jacobi_2d_test.py index bc2d5a4f2b..61982c427f 100644 ---
a/tests/npbench/polybench/jacobi_2d_test.py +++ b/tests/npbench/polybench/jacobi_2d_test.py @@ -47,6 +47,7 @@ def run_jacobi_2d(device_type: dace.dtypes.DeviceType): # Parse the SDFG and apply autopot sdfg = kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) + sdfg(A=A, B=B, TSTEPS=TSTEPS, N=N) elif device_type == dace.dtypes.DeviceType.FPGA: diff --git a/tests/numpy/common.py b/tests/numpy/common.py index 2784c8a0eb..eee176f9c2 100644 --- a/tests/numpy/common.py +++ b/tests/numpy/common.py @@ -122,7 +122,9 @@ def get_rand_arr(ddesc): sdfg.apply_gpu_transformations() dace_result = sdfg(**dace_input) else: - dace_result = dp(**dace_input) + sdfg = dp.to_sdfg(**dace_input) + sdfg.simplify() + dace_result = sdfg(**dace_input) except Exception as e: dace_thrown = e diff --git a/tests/numpy/ndarray_attributes_methods_test.py b/tests/numpy/ndarray_attributes_methods_test.py index 9d8fa20534..8e892c2c38 100644 --- a/tests/numpy/ndarray_attributes_methods_test.py +++ b/tests/numpy/ndarray_attributes_methods_test.py @@ -117,7 +117,7 @@ def test_conj(A: dace.complex64[M, N, N, M]): @compare_numpy_output() -def test_sum(A: dace.float32[M, N, N, M]): +def test_sum__with_different_name(A: dace.float32[M, N, N, M]): return A.sum() @@ -166,7 +166,7 @@ def test_any(): test_min() test_argmin() test_conj() - test_sum() + test_sum__with_different_name() test_mean() test_prod() test_all() diff --git a/tests/python_frontend/fields_and_global_arrays_test.py b/tests/python_frontend/fields_and_global_arrays_test.py index b7f5e46ee9..03cb4c5915 100644 --- a/tests/python_frontend/fields_and_global_arrays_test.py +++ b/tests/python_frontend/fields_and_global_arrays_test.py @@ -585,7 +585,7 @@ def caller(): # Ensure only three globals are created sdfg = caller.to_sdfg() - assert len([k for k in sdfg.arrays if '__g' in k]) == 3 + assert len([k for k in sdfg.arrays if k.startswith('__g')]) == 3 def test_two_inner_methods(): diff --git a/tests/transformations/apply_to_test.py b/tests/transformations/apply_to_test.py index de542b758c..34ff114ac5 100644 --- a/tests/transformations/apply_to_test.py +++ b/tests/transformations/apply_to_test.py @@ -15,6 +15,7 @@ def dbladd(A: dace.float64[100, 100], B: dace.float64[100, 100]): dbl = B return A + dbl * B + @dace.program def unfusable(A: dace.float64[100], B: dace.float64[100, 100]): """Test function of two maps that can not be fused.""" @@ -57,8 +58,12 @@ def test_applyto_pattern(): transient = next(aname for aname, desc in sdfg.arrays.items() if desc.transient) access_node = next(n for n in state.nodes() if isinstance(n, dace.nodes.AccessNode) and n.data == transient) - assert MapFusion.can_be_applied_to(sdfg, first_map_exit=mult_exit, array=access_node, second_map_entry=add_entry) - + assert MapFusion.can_be_applied_to( + sdfg, + first_map_exit=mult_exit, + array=access_node, + second_map_entry=add_entry + ) MapFusion.apply_to(sdfg, first_map_exit=mult_exit, array=access_node, second_map_entry=add_entry) assert len([node for node in state.nodes() if isinstance(node, dace.nodes.MapEntry)]) == 1 diff --git a/tests/transformations/mapfusion_data_races_test.py b/tests/transformations/mapfusion_data_races_test.py index e765ec6978..ff87fd61ec 100644 --- a/tests/transformations/mapfusion_data_races_test.py +++ b/tests/transformations/mapfusion_data_races_test.py @@ -41,6 +41,13 @@ def rw_data_race_3(A: dace.float64[20], B: dace.float64[20]): A[:10] += 3.0 * offset(A[:11]) +@dace.program +def rw_data_race_4(A: dace.float64[20], B: dace.float64[20]): + # This is potentially 
fusable + A += B + A *= 2.0 + + def test_rw_data_race(): sdfg = rw_data_race.to_sdfg(simplify=True) sdfg.apply_transformations_repeated(MapFusion) @@ -50,8 +57,9 @@ def test_rw_data_race_2_mf(): sdfg = rw_data_race_2.to_sdfg(simplify=True) - sdfg.apply_transformations_repeated(MapFusion) + nb_applied = sdfg.apply_transformations_repeated(MapFusion) map_entry_nodes = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, nodes.MapEntry)] + assert nb_applied > 0 assert (len(map_entry_nodes) > 1) @@ -69,8 +77,27 @@ def test_rw_data_race_3_sgf(): assert (len(map_entry_nodes) > 1) + +def test_rw_data_race_3_mf(): + sdfg = rw_data_race_3.to_sdfg(simplify=True) + nb_applied = sdfg.apply_transformations_repeated(MapFusion) + map_entry_nodes = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, nodes.MapEntry)] + assert (len(map_entry_nodes) > 1) + assert nb_applied > 0 + + +def test_rw_data_race_4_mf(): + # It is technically possible to fuse it, because there is only a point wise dependency. + # However, it is very hard to detect and handle correctly. + sdfg = rw_data_race_4.to_sdfg(simplify=True) + sdfg.apply_transformations_repeated(MapFusion) + map_entry_nodes = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, nodes.MapEntry)] + assert (len(map_entry_nodes) >= 1) + + if __name__ == "__main__": test_rw_data_race() - test_rw_data_race_2_mf() test_rw_data_race_2_sgf() + test_rw_data_race_2_mf() test_rw_data_race_3_sgf() + test_rw_data_race_3_mf() + test_rw_data_race_4_mf() diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 724c8c97ee..1e7678c3aa 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -1,12 +1,120 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +from typing import Any, Union, Tuple, Optional + import numpy as np import os import dace -from dace.transformation.dataflow import MapFusion +import copy +import uuid +import pytest + +from dace import SDFG, SDFGState +from dace.sdfg import nodes +from dace.transformation.dataflow import MapFusion, MapExpansion + + +def count_node(sdfg: SDFG, node_type): + nb_nodes = 0 + for rsdfg in sdfg.all_sdfgs_recursive(): + for state in rsdfg.states(): + for node in state.nodes(): + if isinstance(node, node_type): + nb_nodes += 1 + return nb_nodes + +def apply_fusion( + sdfg: SDFG, + removed_maps: Union[int, None] = None, + final_maps: Union[int, None] = None, + unspecific: bool = False, + apply_once: bool = False, + strict_dataflow: bool = True, +) -> SDFG: + """Applies the Map fusion transformation. + + The function checks that the number of maps has been reduced; it is also possible + to specify the number of removed maps. It is also possible to specify the final + number of maps. + If `unspecific` is set to `True` then the function will just apply the + transformation and not check if maps were removed at all. + If `strict_dataflow` is set to `True`, the default, then the function will perform + the fusion in strict dataflow mode.
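+ + For example, `apply_fusion(sdfg, final_maps=1)` applies the transformation + repeatedly and then asserts that exactly one map remains.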
+ """ + org_sdfg = copy.deepcopy(sdfg) + num_maps_before = None if unspecific else count_node(sdfg, nodes.MapEntry) + + try: + with dace.config.temporary_config(): + dace.Config.set("optimizer", "match_exception", value=True) + if apply_once: + sdfg.apply_transformations( + MapFusion(strict_dataflow=strict_dataflow), + validate=True, + validate_all=True + ) + else: + sdfg.apply_transformations_repeated( + MapFusion(strict_dataflow=strict_dataflow), + validate=True, + validate_all=True + ) + except: + org_sdfg.view() + sdfg.view() + raise + + if unspecific: + return sdfg + + num_maps_after = count_node(sdfg, nodes.MapEntry) + has_processed = False + if removed_maps is not None: + has_processed = True + rm = num_maps_before - num_maps_after + if not (rm == removed_maps): + sdfg.view() + assert rm == removed_maps, f"Expected to remove {removed_maps} but removed {rm}" + if final_maps is not None: + has_processed = True + if not (final_maps == num_maps_after): + sdfg.view() + assert final_maps == num_maps_after, f"Expected that only {final_maps} maps remain, but there are sill {num_maps_after}." + if not has_processed: + if not (num_maps_after < num_maps_before): + sdfg.view() + assert num_maps_after < num_maps_before, f"Maps after: {num_maps_after}; Maps before: {num_maps_before}" + return sdfg @dace.program -def fusion(A: dace.float32[10, 20], B: dace.float32[10, 20], out: dace.float32[1]): +def fusion_simple(A: dace.float32[10, 20], B: dace.float32[10, 20], out: dace.float32[1]): + tmp = dace.define_local([10, 20], dtype=A.dtype) + tmp_2 = dace.define_local([10, 20], dtype=A.dtype) + for i, j in dace.map[0:10, 0:20]: + with dace.tasklet: + a << A[i, j] + b >> tmp[i, j] + + b = a * a + + for i, j in dace.map[0:10, 0:20]: + with dace.tasklet: + a << tmp[i, j] + b << B[i, j] + c >> tmp_2[i, j] + + c = a + b + + for i, j in dace.map[0:10, 0:20]: + with dace.tasklet: + a << tmp_2[i, j] + b >> out(1, lambda a, b: a + b)[0] + + b = a + + +@dace.program +def fusion_rename(A: dace.float32[10, 20], B: dace.float32[10, 20], out: dace.float32[1]): tmp = dace.define_local([10, 20], dtype=A.dtype) tmp_2 = dace.define_local([10, 20], dtype=A.dtype) for i, j in dace.map[0:10, 0:20]: @@ -66,12 +174,99 @@ def fusion_chain(A: dace.float32[10, 20], B: dace.float32[10, 20]): B[:] = tmp2 + 5 +@dace.program +def fusion_with_transient(A: dace.float64[2, 20]): + res = np.ndarray([2, 20], dace.float64) + for i in dace.map[0:20]: + for j in dace.map[0:2]: + with dace.tasklet: + a << A[j, i] + t >> res[j, i] + t = a * a + for i in dace.map[0:20]: + for j in dace.map[0:2]: + with dace.tasklet: + t << res[j, i] + o >> A[j, i] + o = t * 2 + + +@dace.program +def fusion_shared_output(A: dace.float32[10, 20], B: dace.float32[10, 20], C: dace.float32[10, 20]): + tmp = A + 3 + B[:] = tmp * 4 + C[:] = tmp / 6 + + +@dace.program +def fusion_indirect_access(A: dace.float32[100], B: dace.float32[100], idx: dace.int32[30], out: dace.float32[30]): + tmp = (A + B * 2) + 3 + out[:] = tmp[idx] + + +def make_interstate_transient_fusion_sdfg(): + sdfg = dace.SDFG("interstate_transient_fusion") + state1 = sdfg.add_state("state1", is_start_block=True) + state2 = sdfg.add_state_after(state1, "state2") + + for name in ["A", "B", "C", "D"]: + sdfg.add_array(name, shape=(20, 20), dtype=dace.float64, transient=False) + sdfg.arrays["B"].transient = True + + A1, B1, C1 = (state1.add_access(name) for name in ["A", "B", "C"]) + state1.add_mapped_tasklet( + "map_1_1", + map_ranges={"__i0": "0:20", "__i1": "0:20"}, + inputs={"__in1": 
dace.Memlet("A[__i0, __i1]")}, + code="__out = __in1 + 20", + outputs={"__out": dace.Memlet("B[__i0, __i1]")}, + input_nodes={"A": A1}, + output_nodes={"B": B1}, + external_edges=True, + ) + state1.add_mapped_tasklet( + "map_2_1", + map_ranges={"__i0": "0:20", "__i1": "0:20"}, + inputs={"__in1": dace.Memlet("B[__i0, __i1]")}, + code="__out = __in1 + 10", + outputs={"__out": dace.Memlet("C[__i0, __i1]")}, + input_nodes={"B": B1}, + output_nodes={"C": C1}, + external_edges=True, + ) + + B2, D2 = (state2.add_access(name) for name in ["B", "D"]) + state2.add_mapped_tasklet( + "map_1_2", + map_ranges={"__i0": "0:20", "__i1": "0:20"}, + inputs={"__in1": dace.Memlet("B[__i0, __i1]")}, + code="__out = __in1 + 6", + outputs={"__out": dace.Memlet("D[__i0, __i1]")}, + input_nodes={"B": B2}, + output_nodes={"D": D2}, + external_edges=True, + ) + + return sdfg, state1, state2 + + def test_fusion_simple(): - sdfg = fusion.to_sdfg() - sdfg.save(os.path.join('_dacegraphs', 'before1.sdfg')) - sdfg.simplify() - sdfg.apply_transformations_repeated(MapFusion) - sdfg.save(os.path.join('_dacegraphs', 'after1.sdfg')) + sdfg = fusion_simple.to_sdfg(simplify=True) + sdfg = apply_fusion(sdfg, final_maps=1) + + A = np.random.rand(10, 20).astype(np.float32) + B = np.random.rand(10, 20).astype(np.float32) + out = np.zeros(shape=1, dtype=np.float32) + sdfg(A=A, B=B, out=out) + + diff = abs(np.sum(A * A + B) - out) + print('Difference:', diff) + assert diff <= 1e-3 + + +def test_fusion_rename(): + sdfg = fusion_rename.to_sdfg(simplify=True) + sdfg = apply_fusion(sdfg, final_maps=1) A = np.random.rand(10, 20).astype(np.float32) B = np.random.rand(10, 20).astype(np.float32) @@ -83,20 +278,43 @@ def test_fusion_simple(): assert diff <= 1e-3 +def test_fusion_shared(): + sdfg = fusion_shared_output.to_sdfg(simplify=True) + sdfg = apply_fusion(sdfg) + + A = np.random.rand(10, 20).astype(np.float32) + B = np.random.rand(10, 20).astype(np.float32) + C = np.random.rand(10, 20).astype(np.float32) + + B_res = (A + 3) * 4 + C_res = (A + 3) / 6 + sdfg(A=A, B=B, C=C) + + assert np.allclose(B_res, B) + assert np.allclose(C_res, C) + + +def test_indirect_accesses(): + sdfg = fusion_indirect_access.to_sdfg(simplify=True) + sdfg = apply_fusion(sdfg, final_maps=2) + + A = np.random.rand(100).astype(np.float32) + B = np.random.rand(100).astype(np.float32) + idx = ((np.random.rand(30) * 100) % 100).astype(np.int32) + out = np.zeros(shape=30, dtype=np.float32) + + res = ((A + B * 2) + 3)[idx] + sdfg(A=A, B=B, idx=idx, out=out) + + assert np.allclose(res, out) + + def test_multiple_fusions(): - sdfg = multiple_fusions.to_sdfg() - num_nodes_before = len([node for state in sdfg.nodes() for node in state.nodes()]) + sdfg = multiple_fusions.to_sdfg(simplify=True) sdfg.save(os.path.join('_dacegraphs', 'before2.sdfg')) sdfg.simplify() - sdfg.apply_transformations_repeated(MapFusion) - sdfg.save(os.path.join('_dacegraphs', 'after2.sdfg')) - - num_nodes_after = len([node for state in sdfg.nodes() for node in state.nodes()]) - # Ensure that the number of nodes was reduced after transformation - if num_nodes_after >= num_nodes_before: - raise RuntimeError('SDFG was not properly transformed ' - '(nodes before: %d, after: %d)' % (num_nodes_before, num_nodes_after)) + sdfg = apply_fusion(sdfg) A = np.random.rand(10, 20).astype(np.float32) B = np.zeros_like(A) @@ -113,20 +331,9 @@ def test_multiple_fusions(): def test_fusion_chain(): - sdfg = fusion_chain.to_sdfg() - sdfg.save(os.path.join('_dacegraphs', 'before3.sdfg')) + sdfg = 
fusion_chain.to_sdfg(simplify=True) sdfg.simplify() - sdfg.apply_transformations(MapFusion) - num_nodes_before = len([node for state in sdfg.nodes() for node in state.nodes()]) - sdfg.apply_transformations(MapFusion) - sdfg.apply_transformations(MapFusion) - sdfg.save(os.path.join('_dacegraphs', 'after3.sdfg')) - - num_nodes_after = len([node for state in sdfg.nodes() for node in state.nodes()]) - # Ensure that the number of nodes was reduced after transformation - if num_nodes_after >= num_nodes_before: - raise RuntimeError('SDFG was not properly transformed ' - '(nodes before: %d, after: %d)' % (num_nodes_before, num_nodes_after)) + sdfg = apply_fusion(sdfg, final_maps=1) A = np.random.rand(10, 20).astype(np.float32) B = np.zeros_like(A) @@ -136,29 +343,14 @@ def test_fusion_chain(): assert diff <= 1e-4 -@dace.program -def fusion_with_transient(A: dace.float64[2, 20]): - res = np.ndarray([2, 20], dace.float64) - for i in dace.map[0:20]: - for j in dace.map[0:2]: - with dace.tasklet: - a << A[j, i] - t >> res[j, i] - t = a * a - for i in dace.map[0:20]: - for j in dace.map[0:2]: - with dace.tasklet: - t << res[j, i] - o >> A[j, i] - o = t * 2 - def test_fusion_with_transient(): A = np.random.rand(2, 20) expected = A * A * 2 - sdfg = fusion_with_transient.to_sdfg() + sdfg = fusion_with_transient.to_sdfg(simplify=True) sdfg.simplify() - sdfg.apply_transformations(MapFusion) + sdfg = apply_fusion(sdfg, removed_maps=2) + sdfg(A=A) assert np.allclose(A, expected) @@ -191,7 +383,7 @@ def build_sdfg(): return sdfg sdfg = build_sdfg() - sdfg.apply_transformations(MapFusion) + sdfg = apply_fusion(sdfg) A = np.random.rand(N, K) B = np.repeat(np.nan, N) @@ -217,10 +409,12 @@ def inverted_maps(A: dace.int32[10]): sdfg(A=val0) assert np.array_equal(val0, ref) - sdfg.apply_transformations(MapFusion) + # This can not be fused + apply_fusion(sdfg, removed_maps=0) + val1 = np.ndarray((10,), dtype=np.int32) sdfg(A=val1) - assert np.array_equal(val1, ref) + assert np.array_equal(val1, ref), f"REF: {ref}; VAL: {val1}" def test_fusion_with_empty_memlet(): @@ -240,8 +434,7 @@ def inner_product(A: dace.float32[N], B: dace.float32[N], out: dace.float32[1]): out[0] += lsum sdfg = inner_product.to_sdfg(simplify=True) - count = sdfg.apply_transformations_repeated(MapFusion) - assert count == 2 + apply_fusion(sdfg, removed_maps=2) A = np.arange(1024, dtype=np.float32) B = np.arange(1024, dtype=np.float32) @@ -253,9 +446,8 @@ def inner_product(A: dace.float32[N], B: dace.float32[N], out: dace.float32[1]): def test_fusion_with_nested_sdfg_0(): - @dace.program - def fusion_with_nested_sdfg_0(A: dace.int32[10], B: dace.int32[10], C: dace.int32[10]): - tmp = np.empty([10], dtype=np.int32) + def ref(A, B, C): + tmp = np.zeros_like(A) for i in dace.map[0:10]: if C[i] < 0: tmp[i] = B[i] - A[i] @@ -263,9 +455,130 @@ def fusion_with_nested_sdfg_0(A: dace.int32[10], B: dace.int32[10], C: dace.int3 tmp[i] = B[i] + A[i] for i in dace.map[0:10]: A[i] = tmp[i] * 2 - - sdfg = fusion_with_nested_sdfg_0.to_sdfg(simplify=True) - sdfg.apply_transformations(MapFusion) + + def _make_sdfg() -> dace.SDFG: + sdfg = SDFG("fusion_with_nested_sdfg_0") + state = sdfg.add_state(is_start_block=True) + + for name in "ABCT": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["T"].transient = True + + me1, mx1 = state.add_map("first_map", ndrange={"__i0": "0:10"}) + nsdfg = state.add_nested_sdfg( + sdfg=_make_nested_sdfg(), + parent=sdfg, + inputs={"a", "b", "c"}, + outputs={"t"}, + symbol_mapping={}, 
+ ) + + for name in "ABC": + state.add_edge( + state.add_access(name), None, + me1, "IN_" + name, + dace.Memlet(f"{name}[0:10]"), + ) + me1.add_in_connector("IN_" + name) + state.add_edge( + me1, "OUT_" + name, + nsdfg, name.lower(), + dace.Memlet(f"{name}[__i0]"), + ) + me1.add_out_connector("OUT_" + name) + state.add_edge( + nsdfg, "t", + mx1, "IN_T", + dace.Memlet("T[__i0]"), + ) + T = state.add_access("T") + state.add_edge( + mx1, "OUT_T", + T, None, + dace.Memlet("T[0:10]"), + ) + mx1.add_in_connector("IN_T") + mx1.add_out_connector("OUT_T") + + state.add_mapped_tasklet( + "comp2", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("T[__i0]")}, + code="__out = __in1 * 2", + outputs={"__out": dace.Memlet("A[__i0]")}, + input_nodes={T}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + def _make_nested_sdfg() -> dace.SDFG: + sdfg = SDFG("Nested") + + for name in "abct": + sdfg.add_scalar( + name, + dtype=dace.float64, + transient=False, + ) + + state_head = sdfg.add_state("head_state", is_start_block=True) + state_if_guard = sdfg.add_state("state_if_guard") + sdfg.add_edge( + state_head, + state_if_guard, + dace.InterstateEdge( + condition="1", + assignments={"__tmp2": "c < 0.0"}, + ) + ) + + def _make_branch_tasklet( + state: dace.SDFGState, + code: str, + ) -> None: + tasklet = state.add_tasklet( + state.label + "_tasklet", + inputs={"__in1", "__in2"}, + code=code, + outputs={"__out"}, + ) + state.add_edge( + state.add_access("b"), None, + tasklet, "__in1", + dace.Memlet("b[0]"), + ) + state.add_edge( + state.add_access("a"), None, + tasklet, "__in2", + dace.Memlet("a[0]"), + ) + state.add_edge( + tasklet, "__out", + state.add_access("t"), None, + dace.Memlet("t[0]"), + ) + + state_trueb = sdfg.add_state("true_branch") + _make_branch_tasklet(state_trueb, "__out = __in1 - __in2") + state_falseb = sdfg.add_state("false_branch") + _make_branch_tasklet(state_falseb, "__out = __in1 + __in2") + state_if_end = sdfg.add_state("if_join") + + sdfg.add_edge(state_if_guard, state_trueb, dace.InterstateEdge(condition="__tmp2")) + sdfg.add_edge(state_if_guard, state_falseb, dace.InterstateEdge(condition="not __tmp2")) + sdfg.add_edge(state_falseb, state_if_end, dace.InterstateEdge()) + sdfg.add_edge(state_trueb, state_if_end, dace.InterstateEdge()) + sdfg.validate() + return sdfg + + sdfg = _make_sdfg() + apply_fusion(sdfg) for sd in sdfg.all_sdfgs_recursive(): if sd is not sdfg: @@ -277,8 +590,25 @@ def fusion_with_nested_sdfg_0(A: dace.int32[10], B: dace.int32[10], C: dace.int3 assert isinstance(dst, dace.nodes.AccessNode) + args_ref = { + 'A': np.array(np.random.rand(10), dtype=np.float64, copy=True), + 'B': np.array(np.random.rand(10), dtype=np.float64, copy=True), + 'C': np.array(np.random.rand(10) - 0.5, dtype=np.float64, copy=True), + } + args_res = copy.deepcopy(args_ref) + + ref(**args_ref) + sdfg(**args_res) + for arg in args_ref.keys(): + arg_ref = args_ref[arg] + arg_res = args_res[arg] + assert np.allclose(arg_ref, arg_res), f"Failed in {arg}" + + def test_fusion_with_nested_sdfg_1(): - + + # As a side effect this test also ensures that dynamic consumer edges, does not + # impact fusing, i.e. allow that fusion can take place. 
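+    # (A "dynamic" consumer edge is one whose Memlet is created with
+    #  `dynamic=True`, meaning its volume is only an upper bound because the
+    #  read may not happen in every iteration, e.g. (illustrative only):
+    #
+    #      dace.Memlet("tmp[i]", dynamic=True)
+    #
+    #  The conditional reads of `tmp` in the program below give rise to such
+    #  edges.)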
@dace.program def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int32[10]): tmp = np.empty([10], dtype=np.int32) @@ -295,7 +625,7 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 B[i] = tmp[i] * 2 sdfg = fusion_with_nested_sdfg_1.to_sdfg(simplify=True) - sdfg.apply_transformations(MapFusion) + apply_fusion(sdfg) if len(sdfg.states()) != 1: return @@ -310,13 +640,1406 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 assert isinstance(src, dace.nodes.AccessNode) +def test_interstate_fusion(): + """Transient between two maps is used in another state and must become shared. + """ + sdfg, state1, state2 = make_interstate_transient_fusion_sdfg() + + A = np.random.rand(20, 20) + C = np.random.rand(20, 20) + D = np.random.rand(20, 20) + + ref_C = A + 30 + ref_D = A + 26 + + apply_fusion(sdfg, removed_maps=1) + assert sdfg.number_of_nodes() == 2 + assert len([node for node in state1.data_nodes() if node.data == "B"]) == 1 + + sdfg(A=A, C=C, D=D) + + assert np.allclose(C, ref_C) + assert np.allclose(D, ref_D) + + +def test_fuse_indirect_accesses(): + + @dace.program(auto_optimize=False) + def inner_product( + A: dace.float32[20], + B: dace.float32[20], + idx: dace.int32[20], + out: dace.float32[20], + ): + tmp1 = np.empty_like(A) + tmp2 = np.empty_like(A) + for i in dace.map[0:20]: + tmp1[i] = A[i] * B[i] + for i in dace.map[0:20]: + tmp2[i] = tmp1[i] + A[i] + for i in dace.map[0:20]: + with dace.tasklet: + __arr << tmp2(1)[:] + __idx << idx[i] + __out >> out[i] + __out = __arr[__idx] + + sdfg = inner_product.to_sdfg(simplify=True) + assert sdfg.number_of_nodes() == 1 + assert count_node(sdfg, nodes.MapEntry) == 3 + + apply_fusion(sdfg, final_maps=2) + + # The last map, with the indirect access, can not be fused, so check that. + state = next(iter(sdfg.nodes())) + assert len(list(state.sink_nodes())) == 1 + out_node = next(iter(state.sink_nodes())) + assert out_node.data == "out" + assert state.in_degree(out_node) == 1 + + # Now find the last map and the indirect access Tasklet + last_map_exit = next(iter(state.in_edges(out_node))).src + last_map_entry = state.entry_node(last_map_exit) + assert isinstance(last_map_exit, nodes.MapExit) + assert state.in_degree(last_map_exit) == 1 + + indirect_access_tasklet = next(iter(state.in_edges(last_map_exit))).src + assert isinstance(indirect_access_tasklet, nodes.Tasklet) + assert indirect_access_tasklet.code == "__out = __arr[__idx]" # TODO: Regex with connectors + + # The tasklet can only be connected to a map entry. + assert all(in_edge.src is last_map_entry for in_edge in state.in_edges(indirect_access_tasklet)) + + +def make_correction_offset_sdfg( + range_read: bool, + second_read_start: int, +) -> SDFG: + """Make the SDFGs for the `test_offset_correction_*` tests. + + Args: + range_read: If `True` then a range is read in the second map. + if `False` then only a scalar is read. + second_read_start: Where the second map should start reading. 
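+
+    Returns:
+        The newly created and validated SDFG; `MapExpansion` has already been
+        applied, so the multi-dimensional maps are expanded into nests of
+        one-dimensional maps.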
+    """
+    sdfg = SDFG("offset_correction_test")
+    state = sdfg.add_state(is_start_block=True)
+    shapes = {
+        "A": (20, 10),
+        "B": (20, 8),
+        "C": (20, 2) if range_read else (20, 1),
+    }
+    descs = {}
+    for name, shape in shapes.items():
+        _, desc = sdfg.add_array(name, shape, dace.float64, transient=False)
+        descs[name] = desc
+    sdfg.arrays["B"].transient = True
+    A, B, C = (state.add_access(name) for name in sorted(shapes.keys()))
+
+    state.add_mapped_tasklet(
+        "first_map",
+        map_ranges={"i": "0:20", "j": "2:8"},
+        inputs={"__in1": dace.Memlet("A[i, j]")},
+        code="__out = __in1 + 1.0",
+        outputs={"__out": dace.Memlet("B[i, j]")},
+        input_nodes={"A": A},
+        output_nodes={"B": B},
+        external_edges=True,
+    )
+    state.add_mapped_tasklet(
+        "second_map",
+        map_ranges=(
+            {"i": "0:20", "k": "0:2"}
+            if range_read
+            else {"i": "0:20"}
+        ),
+        inputs={"__in1": dace.Memlet(f"B[i, {second_read_start}{'+k' if range_read else ''}]")},
+        code="__out = __in1",
+        outputs={"__out": dace.Memlet(f"C[i, {'k' if range_read else '0'}]")},
+        input_nodes={"B": B},
+        output_nodes={"C": C},
+        external_edges=True,
+    )
+    sdfg.validate()
+    assert sdfg.apply_transformations_repeated(MapExpansion, validate_all=True) > 0
+    return sdfg
+
+
+def test_offset_correction_range_read():
+
+    np.random.seed(42)
+    A = np.random.rand(20, 10)
+    C = np.zeros((20, 2))
+    exp = (A + 1.0)[:, 3:5].copy()
+
+    sdfg = make_correction_offset_sdfg(range_read=True, second_read_start=3)
+
+    sdfg(A=A, C=C)
+    assert np.allclose(C, exp)
+    C[:] = 0.0
+
+    apply_fusion(sdfg)
+
+    sdfg(A=A, C=C)
+    assert np.allclose(C, exp)
+
+
+def test_offset_correction_scalar_read():
+
+    np.random.seed(42)
+    A = np.random.rand(20, 10)
+    C = np.zeros((20, 1))
+    exp = (A + 1.0)[:, 3].copy().reshape((-1, 1))
+
+    sdfg = make_correction_offset_sdfg(range_read=False, second_read_start=3)
+
+    sdfg(A=A, C=C)
+    assert np.allclose(C, exp)
+    C[:] = 0.0
+
+    apply_fusion(sdfg)
+
+    sdfg(A=A, C=C)
+    assert np.allclose(C, exp)
+
+
+def test_offset_correction_empty():
+
+    # With `second_read_start=1` the second map reads `B[i, 1:3]`, but the
+    # first map only writes `B[i, 2:8]`; the read is therefore not fully
+    # covered by what the first map produces, so the maps cannot be fused.
+    # NOTE: The computation itself is meaningless; it only exercises this check.
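+    #
+    # A quick way to see the missing coverage in isolation (a sketch using
+    # `dace.subsets` directly; the transformation performs a more general check):
+    #
+    #     from dace.subsets import Range
+    #     produced = Range.from_string("2:8")  # columns written by the first map
+    #     consumed = Range.from_string("1:3")  # columns read by the second map
+    #     assert not produced.covers(consumed)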
+    sdfg = make_correction_offset_sdfg(range_read=True, second_read_start=1)
+
+    apply_fusion(sdfg, removed_maps=0)
+
+
+def test_different_offsets():
+
+    def expected(A, B):
+        N, M = A.shape
+        return (A + 1) + B[1:(N+1), 2:(M+2)]
+
+    def _make_sdfg(N: int, M: int) -> dace.SDFG:
+        sdfg = dace.SDFG("test_different_offsets")
+        names = ["A", "B", "__tmp", "__return"]
+        def_shape = (N, M)
+        sshape = {"B": (N+1, M+2), "__tmp": (N+1, M+1)}
+        for name in names:
+            sdfg.add_array(
+                name,
+                shape=sshape.get(name, def_shape),
+                dtype=dace.float64,
+                transient=False,
+            )
+        sdfg.arrays["__tmp"].transient = True
+
+        state = sdfg.add_state(is_start_block=True)
+        A, B, _tmp, _return = (state.add_access(name) for name in names)
+
+        state.add_mapped_tasklet(
+            "comp1",
+            map_ranges={"__i0": f"0:{N}", "__i1": f"0:{M}"},
+            inputs={"__in": dace.Memlet("A[__i0, __i1]")},
+            code="__out = __in + 1.0",
+            outputs={"__out": dace.Memlet("__tmp[__i0 + 1, __i1 + 1]")},
+            input_nodes={A},
+            output_nodes={_tmp},
+            external_edges=True,
+        )
+        state.add_mapped_tasklet(
+            "comp2",
+            map_ranges={"__i0": f"0:{N}", "__i1": f"0:{M}"},
+            inputs={
+                "__in1": dace.Memlet("__tmp[__i0 + 1, __i1 + 1]"),
+                "__in2": dace.Memlet("B[__i0 + 1, __i1 + 2]"),
+            },
+            code="__out = __in1 + __in2",
+            outputs={"__out": dace.Memlet("__return[__i0, __i1]")},
+            input_nodes={_tmp, B},
+            output_nodes={_return},
+            external_edges=True,
+        )
+
+        sdfg.validate()
+        return sdfg
+
+    N, M = 14, 17
+    sdfg = _make_sdfg(N, M)
+    apply_fusion(sdfg, final_maps=1)
+
+    A = np.array(np.random.rand(N, M), dtype=np.float64, copy=True)
+    B = np.array(np.random.rand(N + 1, M + 2), dtype=np.float64, copy=True)
+
+    ref = expected(A, B)
+    res = sdfg(A=A, B=B)
+    assert np.allclose(ref, res)
+
+
+def _make_strict_dataflow_sdfg_pointwise(
+    input_data: str = "A",
+    intermediate_data: str = "T",
+    output_data: Optional[str] = None,
+    input_read: str = "__i0",
+    output_write: Optional[str] = None,
+) -> Tuple[dace.SDFG, dace.SDFGState]:
+    """
+    Creates the SDFG for the strict data flow tests.
+
+    The SDFG will read and write into `A`, but it is pointwise, thus the Maps can
+    be fused. Furthermore, this particular SDFG guarantees that no data race occurs.
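+
+    Args:
+        input_data: Name of the array the first map reads from.
+        intermediate_data: Name of the array that connects the two maps.
+        output_data: Name of the array the second map writes to; defaults to
+            `input_data`.
+        input_read: Index expression used for reading `input_data`.
+        output_write: Index expression used for writing `output_data`; defaults
+            to `input_read`.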
+ """ + if output_data is None: + output_data = input_data + if output_write is None: + output_write = input_read + + sdfg = dace.SDFG(f"strict_dataflow_sdfg_pointwise_{str(uuid.uuid1()).replace('-', '_')}") + state = sdfg.add_state(is_start_block=True) + for name in {input_data, intermediate_data, output_data}: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + + if intermediate_data not in {input_data, output_data}: + sdfg.arrays[intermediate_data].transient = True + + input_node, intermediate_node, output_node = (state.add_access(name) for name in [input_data, intermediate_data, output_data]) + + state.add_mapped_tasklet( + "first_comp", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet(f"{input_data}[{input_read}]")}, + code="__out = __in1 + 2.0", + outputs={"__out": dace.Memlet(f"{intermediate_data}[__i0]")}, + input_nodes={input_node}, + output_nodes={intermediate_node}, + external_edges=True, + ) + state.add_mapped_tasklet( + "second_comp", + map_ranges={"__i1": "0:10"}, + inputs={"__in1": dace.Memlet(f"{intermediate_data}[__i1]")}, + code="__out = __in1 + 3.0", + outputs={"__out": dace.Memlet(f"{output_data}[{output_write}]")}, + input_nodes={intermediate_node}, + output_nodes={output_node}, + external_edges=True, + ) + sdfg.validate() + return sdfg, state + + +def test_fusion_strict_dataflow_pointwise(): + sdfg, state = _make_strict_dataflow_sdfg_pointwise(input_data="A") + + # However, if strict dataflow is disabled, then it will be able to fuse. + apply_fusion(sdfg, removed_maps=1, strict_dataflow=False) + + +def test_fusion_strict_dataflow_not_pointwise(): + sdfg, state = _make_strict_dataflow_sdfg_pointwise( + input_data="A", + input_read="__i0", + output_write="9 - __i0", + ) + + # Because the dependency is not pointwise even disabling strict dataflow + # will not make it work. + apply_fusion(sdfg, removed_maps=0, strict_dataflow=False) + + +def test_fusion_dataflow_intermediate(): + sdfg, _ = _make_strict_dataflow_sdfg_pointwise( + input_data="A", + intermediate_data="O", + output_data="O", + ) + apply_fusion(sdfg, removed_maps=0, strict_dataflow=True) + + # Because the intermediate is also output of the second map it is not possible + # to fuse even without strict dataflow mode. + apply_fusion(sdfg, removed_maps=0, strict_dataflow=False) + + +def test_fusion_dataflow_intermediate_2(): + # The transformation applies for two reasons, first reading and writing `A` + # is pointwise. Furthermore, there is no further access to `A` after the + # intermediate node. Note that if the second map would also have an output + # that refers to `A` then the transformation would not apply regardless + # of the strict dataflow mode. + sdfg, state = _make_strict_dataflow_sdfg_pointwise( + input_data="A", + intermediate_data="A", + output_data="O", + ) + apply_fusion(sdfg, removed_maps=1, strict_dataflow=True) + map_exit = next(iter(node for node in state.nodes() if isinstance(node, nodes.MapExit))) + assert state.out_degree(map_exit) == 2 + assert {"A", "O"} == {edge.dst.data for edge in state.out_edges(map_exit) if isinstance(edge.dst, nodes.AccessNode)} + + +def test_fusion_dataflow_intermediate_3(): + # This is exactly the same situation as in `test_fusion_dataflow_intermediate_2()` + # with the exception that now the access to `A` is no longer pointwise, thus + # the transformation does not apply. Note that this SDFG is wrong, it is only + # here to show that the case is detected. 
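+    # (Concretely: iteration `__i0 = 0` of a fused map would read `A[9]`,
+    #  while `A[9]` is only (over)written by iteration `__i0 = 9`, so reads
+    #  and writes to `A` no longer pair up per index and the result would
+    #  depend on the iteration order.)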
+    sdfg, state = _make_strict_dataflow_sdfg_pointwise(
+        input_data="A",
+        intermediate_data="A",
+        output_data="O",
+        input_read="9 - __i0",
+        output_write="__i0",
+    )
+    apply_fusion(sdfg, removed_maps=0, strict_dataflow=True)
+
+
+def test_fusion_dataflow_intermediate_downstream():
+    # Because the intermediate `T` is used downstream again,
+    # the transformation cannot apply.
+    sdfg, state = _make_strict_dataflow_sdfg_pointwise(
+        input_data="A",
+        intermediate_data="T",
+        output_data="output_1",
+    )
+    sdfg.arrays["output_1"].transient = False
+    sdfg.arrays["T"].transient = True
+    output_1 = next(iter(dnode for dnode in state.sink_nodes()))
+    assert isinstance(output_1, nodes.AccessNode) and output_1.data == "output_1"
+
+    # Make the real output node.
+    sdfg.arrays["O"] = sdfg.arrays["A"].clone()
+    state.add_mapped_tasklet(
+        "downstream_computation",
+        map_ranges={"__i0": "0:10"},
+        inputs={"__in1": dace.Memlet("output_1[__i0]")},
+        code="__out = __in1 + 10.0",
+        outputs={"__out": dace.Memlet("T[__i0]")},
+        input_nodes={output_1},
+        external_edges=True,
+    )
+
+    # Make another state where `T` is written back, such that it is not dead data flow.
+    state2 = sdfg.add_state_after(state)
+    sdfg.add_datadesc("output_2", sdfg.arrays["output_1"].clone())
+    state2.add_nedge(
+        state2.add_access("T"),
+        state2.add_access("output_2"),
+        sdfg.make_array_memlet("T"),
+    )
+    sdfg.validate()
+
+    apply_fusion(sdfg, removed_maps=0, strict_dataflow=True)
+
+    # However, without strict dataflow, the merge is possible.
+    apply_fusion(sdfg, removed_maps=1, strict_dataflow=False)
+    assert state.in_degree(output_1) == 1
+    assert state.out_degree(output_1) == 1
+    assert all(isinstance(edge.src, nodes.MapExit) for edge in state.in_edges(output_1))
+    assert all(isinstance(edge.dst, nodes.MapEntry) for edge in state.out_edges(output_1))
+
+    upper_map_exit = next(iter(edge.src for edge in state.in_edges(output_1)))
+    assert isinstance(upper_map_exit, nodes.MapExit)
+    assert state.out_degree(upper_map_exit) == 2
+    assert {"T", "output_1"} == {edge.dst.data for edge in state.out_edges(upper_map_exit) if isinstance(edge.dst, nodes.AccessNode)}
+
+
+def test_fusion_non_strict_dataflow_implicit_dependency():
+    """
+    This test checks whether the fusion respects implicit dependencies that are
+    given by access nodes.
+
+    The situation it builds could arise if non-strict dataflow is enabled; the
+    test ensures that the fusion does not continue fusing in that situation.
+ """ + sdfg = dace.SDFG("fusion_strict_dataflow_implicit_dependency_sdfg") + state = sdfg.add_state(is_start_block=True) + names = ["A", "B", "T1", "T2", "C"] + for name in names: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["T1"].transient = True + sdfg.arrays["T2"].transient = True + + me, mx = state.add_map( + "first_map", + ndrange={"__i0": "0:10"} + ) + tskl1 = state.add_tasklet( + "tskl1", + inputs={"__in1", "__in2"}, + code="__out = __in1 * __in2", + outputs={"__out"} + ) + tskl2 = state.add_tasklet( + "tskl2", + inputs={"__in1", "__in2"}, + code="__out = (__in1 + __in2) / 2", + outputs={"__out"} + ) + A, B, T1, T2 = (state.add_access(name) for name in names[:-1]) + + state.add_edge(A, None, me, "IN_A", dace.Memlet("A[0:10]")) + state.add_edge(B, None, me, "IN_B", dace.Memlet("B[0:10]")) + me.add_in_connector("IN_A") + me.add_in_connector("IN_B") + + state.add_edge(me, "OUT_A", tskl1, "__in1", dace.Memlet("A[__i0]")) + state.add_edge(me, "OUT_B", tskl1, "__in2", dace.Memlet("B[__i0]")) + state.add_edge(me, "OUT_A", tskl2, "__in1", dace.Memlet("A[__i0]")) + state.add_edge(me, "OUT_B", tskl2, "__in2", dace.Memlet("B[__i0]")) + me.add_out_connector("OUT_A") + me.add_out_connector("OUT_B") + + state.add_edge(tskl1, "__out", mx, "IN_T1", dace.Memlet("T1[__i0]")) + state.add_edge(tskl2, "__out", mx, "IN_T2", dace.Memlet("T2[__i0]")) + mx.add_in_connector("IN_T1") + mx.add_in_connector("IN_T2") + + state.add_edge(mx, "OUT_T1", T1, None, dace.Memlet("T1[0:10]")) + state.add_edge(mx, "OUT_T2", T2, None, dace.Memlet("T2[0:10]")) + mx.add_out_connector("OUT_T1") + mx.add_out_connector("OUT_T2") + + state.add_mapped_tasklet( + "second_map", + map_ranges={"__in0": "0:10"}, + inputs={"__in1": dace.Memlet("T1[__i0]")}, + code="if __in1 < 0.5:\n\t__out = 100.", + outputs={"__out": dace.Memlet("T2[__i0]", dynamic=True)}, + input_nodes={T1}, + external_edges=True, + ) + + state2 = sdfg.add_state_after(state) + state2.add_edge( + state2.add_access("T2"), + None, + state2.add_access("C"), + None, + dace.Memlet("T2[0:10] -> [0:10]"), + ) + sdfg.validate() + + apply_fusion(sdfg, removed_maps=0, strict_dataflow=False) + + +def _make_inner_conflict_shared_scalar( + has_conflict: bool, +) -> dace.SDFG: + """Generate the SDFG for tests with the inner dependency. + + If `has_conflict` is `True` then a transient scalar is used inside both Map bodies. + Therefore, `MapFusion` should not be able to fuse them. + In case `has_conflict` is `False` then different scalars are used which allows + fusing the two maps. + """ + sdfg = dace.SDFG( + "inner_map_dependency_sdfg" + if has_conflict + else "inner_map_dependency_resolved_sdfg" + ) + state = sdfg.add_state(is_start_block=True) + + name_arrays = ["A", "T", "C"] + for aname in name_arrays: + sdfg.add_array( + aname, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["T"].transient = True + + name_scalars = ["s", "s"] if has_conflict else ["s1", "s2"] + for sname in set(name_scalars): + sdfg.add_scalar( + sname, + dtype=dace.float64, + transient=True, + ) + A, T, C = (state.add_access(aname) for aname in name_arrays) + s1, s2 = (state.add_access(sname) for sname in name_scalars) + + me1, mx1 = state.add_map( + "map_1", + ndrange={"__i0": "0:10"}, + ) + tsklt1 = state.add_tasklet( + "tskl1", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 + 1.0", + ) + + # Create the first map series. 
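+    # (If the two maps were fused while `has_conflict` is `True`, both
+    #  tasklets would read and write the single scalar `s` inside one map
+    #  scope, i.e. they would race on the same storage location; with the
+    #  distinct scalars `s1`/`s2` no such conflict exists.)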
+ state.add_edge( + A, None, + me1, "IN_A", + dace.Memlet("A[0:10]") + ) + me1.add_in_connector("IN_A") + state.add_edge( + me1, "OUT_A", + s1, None, + dace.Memlet("A[__i0] -> [0]") + ) + me1.add_out_connector("OUT_A") + state.add_edge( + s1, None, + tsklt1, "__in1", + dace.Memlet(f"{s1.data}[0]") + ) + state.add_edge( + tsklt1, "__out", + mx1, "IN_T", + dace.Memlet("T[__i0]") + ) + mx1.add_in_connector("IN_T") + state.add_edge( + mx1, "OUT_T", + T, None, + dace.Memlet("T[0:10]") + ) + mx1.add_out_connector("OUT_T") + + # Create the second map. + me2, mx2 = state.add_map( + "map_2", + ndrange={"__i0": "0:10"}, + ) + tsklt2 = state.add_tasklet( + "tskl2", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 + 3.0", + ) + + state.add_edge( + T, None, + me2, "IN_T", + dace.Memlet("T[0:10]") + ) + me2.add_in_connector("IN_T") + state.add_edge( + me2, "OUT_T", + s2, None, + dace.Memlet("T[__i0]") + ) + me2.add_out_connector("OUT_T") + state.add_edge( + s2, None, + tsklt2, "__in1", + dace.Memlet(f"{s2.data}[0]") + ) + state.add_edge( + tsklt2, "__out", + mx2, "IN_C", + dace.Memlet("C[__i0]") + ) + mx2.add_in_connector("IN_C") + state.add_edge( + mx2, "OUT_C", + C, None, + dace.Memlet("C[0:10]") + ) + mx2.add_out_connector("OUT_C") + sdfg.validate() + return sdfg + + +def test_inner_map_dependency(): + # Because the scalar is not shared the maps can not be fused. + sdfg = _make_inner_conflict_shared_scalar(has_conflict=True) + apply_fusion(sdfg, removed_maps=0, final_maps=2) + + +def test_inner_map_dependency_resolved(): + # Because the scalars are different, the scalar + sdfg = _make_inner_conflict_shared_scalar(has_conflict=False) + apply_fusion(sdfg, removed_maps=1, final_maps=1) + + +def _impl_fusion_intermediate_different_access( + modified_shape: bool, + traditional_memlet_direction: bool +): + def ref(A, B): + T = np.zeros((A.shape[0] + 1, 2)) + for i in range(A.shape[0]): + T[i + 1, 0] = A[i] * 2 + T[i + 1, 1] = A[i] / 2 + for j in range(A.shape[0]): + B[j] = np.sin(T[j+1, 1]) + + sdfg = dace.SDFG("fusion_intermediate_different_access_sdfg") + state = sdfg.add_state(is_start_block=True) + for name in "AB": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.add_array( + "T", + shape=(11, 2), + dtype=dace.float64, + transient=True, + ) + + # For this intermediate, which essentially represents `[A[i] * 2, A[i] / 2]` in + # the reference above, there are two important remarks: + # - It exists because one data stream, i.e. `T[i + 1, 1]` would be dead data flow + # and currently the transformation can not handle this. + # - The strange shape is because the transformation can not handle this case. + # This is a limitation of the implementation. 
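+    # (For reference: the second map below only ever reads `T[__i1 + 1, 1]`,
+    #  so the other component written through `temp` is produced but never
+    #  consumed afterwards.)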
+ sdfg.add_array( + "temp", + shape=( + (1, 2,) + if modified_shape + else (2,) + ), + dtype=dace.float64, + transient=True, + ) + + A, B, T, temp = (state.add_access(name) for name in ["A", "B", "T", "temp"]) + + me1, mx1 = state.add_map( + "first_map", + ndrange={"__i0": "0:10"}, + ) + + state.add_edge( + A, None, + me1, "IN_A", + dace.Memlet("A[0:10]") + ) + me1.add_in_connector("IN_A") + me1.add_out_connector("OUT_A") + + tsklt1_1 = state.add_tasklet( + "tsklt1_1", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 * 2.0", + ) + state.add_edge( + me1, "OUT_A", + tsklt1_1, "__in1", + dace.Memlet("A[__i0]") + ) + state.add_edge( + tsklt1_1, "__out", + temp, None, + dace.Memlet( + "temp[0, 0]" + if modified_shape + else "temp[0]" + ) + ) + + tsklt1_2 = state.add_tasklet( + "tsklt1_2", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 / 2.0", + ) + state.add_edge( + me1, "OUT_A", + tsklt1_2, "__in1", + dace.Memlet("A[__i0]") + ) + state.add_edge( + tsklt1_2, "__out", + temp, None, + dace.Memlet( + "temp[0, 1]" + if modified_shape + else "temp[1]" + ) + ) + + temp_subset = ( + "0, 0:2" + if modified_shape + else "0:2" + ) + T_subset = "__i0 + 1, 0:2" + + if traditional_memlet_direction: + mem_data = "T" + mem_subset = T_subset + mem_other_subset = temp_subset + else: + mem_data = "temp" + mem_subset = temp_subset + mem_other_subset = T_subset + + state.add_edge( + temp, None, + mx1, "IN_temp", + dace.Memlet(f"{mem_data}[{mem_subset}] -> [{mem_other_subset}]") + ) + state.add_edge( + mx1, "OUT_temp", + T, None, + dace.Memlet("T[1:11, 0:2]") + ) + mx1.add_in_connector("IN_temp") + mx1.add_out_connector("OUT_temp") + + state.add_mapped_tasklet( + "comp2", + map_ranges={"__i1": "0:10"}, + inputs={"__in1": dace.Memlet("T[__i1 + 1, 1]")}, + code="__out = math.sin(__in1)", + outputs={"__out": dace.Memlet("B[__i1]")}, + input_nodes={T}, + output_nodes={B}, + external_edges=True, + ) + sdfg.validate() + + apply_fusion(sdfg, removed_maps=1, final_maps=1) + + args_ref = { + 'A': np.array(np.random.rand(10), dtype=np.float64, copy=True), + 'B': np.array(np.random.rand(10), dtype=np.float64, copy=True), + } + args_res = copy.deepcopy(args_ref) + + ref(**args_ref) + sdfg(**args_res) + for arg in args_ref.keys(): + arg_ref = args_ref[arg] + arg_res = args_res[arg] + assert np.allclose(arg_ref, arg_res) + + +def test_fusion_intermediate_different_access(): + _impl_fusion_intermediate_different_access(modified_shape=False, traditional_memlet_direction=False) + + +def test_fusion_intermediate_different_access_2(): + _impl_fusion_intermediate_different_access(modified_shape=False, traditional_memlet_direction=True) + + +def test_fusion_intermediate_different_access_mod_shape(): + _impl_fusion_intermediate_different_access(modified_shape=True, traditional_memlet_direction=False) + + +def test_fusion_intermediate_different_access_mod_shape_2(): + _impl_fusion_intermediate_different_access(modified_shape=True, traditional_memlet_direction=True) + + +@pytest.mark.skip(reason="This feature is not yet fully supported.") +def test_fusion_multiple_producers_consumers(): + """Multiple producer and consumer nodes. + + This test is very similar to the `test_fusion_intermediate_different_access()` + and `test_fusion_intermediate_different_access_mod_shape()` test, with the + exception that now full data is used in the second map. + However, currently `MapFusion` only supports a single producer, thus this test can + not run. 
+ """ + def ref(A, B): + T = np.zeros((A.shape[0], 2)) + for i in range(A.shape[0]): + T[i, 0] = A[i] * 2 + T[i, 1] = A[i] / 2 + for j in range(A.shape[0]): + B[j] = np.sin(T[j, 1]) + np.cos(T[j, 0]) + + sdfg = dace.SDFG("fusion_multiple_producers_consumers_sdfg") + state = sdfg.add_state(is_start_block=True) + for name in "AB": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.add_array( + "T", + shape=(10, 2), + dtype=dace.float64, + transient=True, + ) + + A, B, T = (state.add_access(name) for name in ["A", "B", "T"]) + + me1, mx1 = state.add_map( + "first_map", + ndrange={"__i0": "0:10"}, + ) + + state.add_edge( + A, None, + me1, "IN_A", + dace.Memlet("A[0:10]") + ) + me1.add_in_connector("IN_A") + me1.add_out_connector("OUT_A") + + tsklt1_1 = state.add_tasklet( + "tsklt1_1", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 * 2.0", + ) + state.add_edge( + me1, "OUT_A", + tsklt1_1, "__in1", + dace.Memlet("A[__i0]") + ) + state.add_edge( + tsklt1_1, "__out", + mx1, "IN_T", + dace.Memlet("T[__i0, 0]") + ) + + tsklt1_2 = state.add_tasklet( + "tsklt1_2", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 / 2.0", + ) + state.add_edge( + me1, "OUT_A", + tsklt1_2, "__in1", + dace.Memlet("A[__i0]") + ) + state.add_edge( + tsklt1_2, "__out", + mx1, "IN_T", + dace.Memlet("T[__i0, 1]") + ) + mx1.add_in_connector("IN_T") + + state.add_edge( + mx1, "OUT_T", + T, None, + dace.Memlet("T[0:10, 0:2]"), + ) + mx1.add_out_connector("OUT_T") + + state.add_mapped_tasklet( + "comp2", + map_ranges={"__i1": "0:10"}, + inputs={ + "__in1": dace.Memlet("T[__i1, 1]"), + "__in2": dace.Memlet("T[__i1, 0]"), + }, + code="__out = math.sin(__in1) + math.cos(__in2)", + outputs={"__out": dace.Memlet("B[__i1]")}, + input_nodes={T}, + output_nodes={B}, + external_edges=True, + ) + sdfg.validate() + + apply_fusion(sdfg, removed_maps=1, final_maps=1) + + args_ref = { + 'A': np.array(np.random.rand(10), dtype=np.float64, copy=True), + 'B': np.array(np.random.rand(10), dtype=np.float64, copy=True), + } + args_res = copy.deepcopy(args_ref) + + ref(**args_ref) + sdfg(**args_res) + for arg in args_ref.keys(): + arg_ref = args_ref[arg] + arg_res = args_res[arg] + assert np.allclose(arg_ref, arg_res) + + +def test_fusion_multiple_consumers(): + """The intermediate is consumed multiple times in the second map. + """ + def ref(A, B, C): + T = np.zeros_like(A) + for i in range(A.shape[0]): + T[i] = np.sin(A[i] * 2) + for j in range(A.shape[0]): + B[j] = T[j] * 3. + C[j] = T[j] - 1. 
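+
+    # (Note on the construction below: `T` enters the second map through a
+    #  single map-entry edge whose Memlet carries `volume=20`, i.e. 10 elements
+    #  that are each consumed by both tasklets `tsklt2_1` and `tsklt2_2`.)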
+ + sdfg = dace.SDFG("fusion_multiple_consumers_sdfg") + state = sdfg.add_state(is_start_block=True) + for name in "ABCT": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["T"].transient = True + + A, B, C, T = (state.add_access(name) for name in ["A", "B", "C", "T"]) + + state.add_mapped_tasklet( + "comp1", + map_ranges={"__i1": "0:10"}, + inputs={ + "__in1": dace.Memlet("A[__i1]"), + }, + code="__out = math.sin(2 * __in1)", + outputs={"__out": dace.Memlet("T[__i1]")}, + input_nodes={A}, + output_nodes={T}, + external_edges=True, + ) + + me2, mx2 = state.add_map( + "second_map", + ndrange={"__i0": "0:10"}, + ) + + state.add_edge( + T, None, + me2, "IN_T", + dace.Memlet("T[0:10]", volume=20) + ) + me2.add_in_connector("IN_T") + me2.add_out_connector("OUT_T") + + tsklt2_1 = state.add_tasklet( + "tsklt2_1", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 * 3.0", + ) + state.add_edge( + me2, "OUT_T", + tsklt2_1, "__in1", + dace.Memlet("T[__i0]") + ) + state.add_edge( + tsklt2_1, "__out", + mx2, "IN_B", + dace.Memlet("B[__i0]") + ) + + tsklt2_2 = state.add_tasklet( + "tsklt2_2", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 - 1.0", + ) + state.add_edge( + me2, "OUT_T", + tsklt2_2, "__in1", + dace.Memlet("T[__i0]") + ) + state.add_edge( + tsklt2_2, "__out", + mx2, "IN_C", + dace.Memlet("C[__i0]") + ) + mx2.add_in_connector("IN_B") + mx2.add_in_connector("IN_C") + + state.add_edge( + mx2, "OUT_B", + B, None, + dace.Memlet("B[0:10]"), + ) + state.add_edge( + mx2, "OUT_C", + C, None, + dace.Memlet("C[0:10]"), + ) + mx2.add_out_connector("OUT_B") + mx2.add_out_connector("OUT_C") + sdfg.validate() + + apply_fusion(sdfg, removed_maps=1, final_maps=1) + + args_ref = { + 'A': np.array(np.random.rand(10), dtype=np.float64, copy=True), + 'B': np.array(np.random.rand(10), dtype=np.float64, copy=True), + 'C': np.array(np.random.rand(10), dtype=np.float64, copy=True), + } + args_res = copy.deepcopy(args_ref) + + ref(**args_ref) + sdfg(**args_res) + for arg in args_ref.keys(): + arg_ref = args_ref[arg] + arg_res = args_res[arg] + assert np.allclose(arg_ref, arg_res) + + +def test_fusion_different_global_accesses(): + + def ref(A, B): + T = np.zeros_like(A) + for i in range(10): + T[i] = A[i] - B[i + 1] + for i in range(10): + A[i] = np.sin(T[i]) + B[i + 1] = np.cos(T[i]) + + sdfg = dace.SDFG("fusion_different_global_accesses_sdfg") + state = sdfg.add_state(is_start_block=True) + for name in "ABT": + sdfg.add_array( + name, + shape=(11,), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["T"].transient = True + T = state.add_access("T") + + state.add_mapped_tasklet( + "first_comp", + map_ranges={"__i0": "0:10"}, + inputs={ + "__in1": dace.Memlet("A[__i0]"), + "__in2": dace.Memlet("B[__i0 + 1]") + }, + code="__out = __in1 - __in2", + outputs={"__out": dace.Memlet("T[__i0]")}, + output_nodes={T}, + external_edges=True, + ) + state.add_mapped_tasklet( + "second_comp", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("T[__i0]")}, + code="__out1 = math.sin(__in1)\n__out2 = math.cos(__in1)", + outputs={ + "__out1": dace.Memlet("A[__i0]"), + "__out2": dace.Memlet("B[__i0 + 1]"), + }, + input_nodes={T}, + external_edges=True, + ) + sdfg.validate() + + apply_fusion(sdfg, removed_maps=1, final_maps=1) + + args_ref = { + 'A': np.array(np.random.rand(11), dtype=np.float64, copy=True), + 'B': np.array(np.random.rand(11), dtype=np.float64, copy=True), + } + args_res = copy.deepcopy(args_ref) + + ref(**args_ref) + 
sdfg(**args_res)
+    for arg in args_ref.keys():
+        arg_ref = args_ref[arg]
+        arg_res = args_res[arg]
+        assert np.allclose(arg_ref, arg_res)
+
+
+def test_fusion_dynamic_producer():
+
+    def ref(A, B):
+        for i in range(10):
+            if B[i] < 0.5:
+                A[i] = 0.0
+        for i in range(10):
+            B[i] = np.sin(A[i])
+
+    sdfg = dace.SDFG("fusion_dynamic_producer_sdfg")
+    state = sdfg.add_state(is_start_block=True)
+    for name in "AB":
+        sdfg.add_array(
+            name,
+            shape=(10,),
+            dtype=dace.float64,
+            transient=False,
+        )
+    B_top, B_bottom, A = (state.add_access(name) for name in "BBA")
+
+    state.add_mapped_tasklet(
+        "comp1",
+        map_ranges={"__i0": "0:10"},
+        inputs={"__in1": dace.Memlet("B[__i0]")},
+        code="if __in1 < 0.5:\n\t__out = 0.0",
+        outputs={"__out": dace.Memlet("A[__i0]", dynamic=True)},
+        input_nodes={B_top},
+        output_nodes={A},
+        external_edges=True,
+    )
+    state.add_mapped_tasklet(
+        "comp2",
+        map_ranges={"__i0": "0:10"},
+        inputs={"__in1": dace.Memlet("A[__i0]")},
+        code="__out = math.sin(__in1)",
+        outputs={"__out": dace.Memlet("B[__i0]")},
+        input_nodes={A},
+        output_nodes={B_bottom},
+        external_edges=True,
+    )
+    sdfg.validate()
+
+    # It is not specified whether dynamic Memlets can be handled, so we pass
+    # `unspecific=True`, i.e. the helper only validates. The numerical check
+    # below then verifies that whatever the transformation did was correct.
+    apply_fusion(sdfg, unspecific=True)
+
+    args_ref = {
+        'A': np.array(np.random.rand(10), dtype=np.float64, copy=True),
+        'B': np.array(np.random.rand(10), dtype=np.float64, copy=True),
+    }
+    args_res = copy.deepcopy(args_ref)
+
+    ref(**args_ref)
+    csdfg = sdfg.compile()
+    csdfg(**args_res)
+    for arg in args_ref.keys():
+        arg_ref = args_ref[arg]
+        arg_res = args_res[arg]
+        assert np.allclose(arg_ref, arg_res)
+
+
+def test_fusion_intrinsic_memlet_direction():
+
+    def ref(A, B):
+        T = A + 10.0
+        B[:] = np.sin(T)
+
+    sdfg = dace.SDFG("fusion_intrinsic_memlet_direction_sdfg")
+    state = sdfg.add_state(is_start_block=True)
+
+    for name in "ATB":
+        sdfg.add_array(
+            name,
+            shape=(10, 11, 12),
+            dtype=dace.float64,
+            transient=False,
+        )
+    sdfg.arrays["T"].transient = True
+
+    for num in "12":
+        sdfg.add_scalar(
+            "t" + num,
+            dtype=dace.float64,
+            transient=True,
+        )
+
+    A, T, B, t1, t2 = (state.add_access(name) for name in ["A", "T", "B", "t1", "t2"])
+
+    tsklt1, me1, mx1 = state.add_mapped_tasklet(
+        "comp1",
+        map_ranges={
+            "__i1": "0:10",
+            "__i2": "0:11",
+            "__i3": "0:12",
+        },
+        inputs={"__in1": dace.Memlet("A[__i1, __i2, __i3]")},
+        code="__out = __in1 + 10.0",
+        outputs={"__out": dace.Memlet("T[__i1, __i2, __i3]")},
+        input_nodes={A},
+        output_nodes={T},
+        external_edges=True,
+    )
+
+    tsklt2, me2, mx2 = state.add_mapped_tasklet(
+        "comp2",
+        map_ranges={
+            "__i1": "0:10",
+            "__i2": "0:11",
+            "__i3": "0:12",
+        },
+        inputs={"__in1": dace.Memlet("T[__i1, __i2, __i3]")},
+        code="__out = math.sin(__in1)",
+        outputs={"__out": dace.Memlet("B[__i1, __i2, __i3]")},
+        input_nodes={T},
+        output_nodes={B},
+        external_edges=True,
+    )
+
+    for me in [me1, me2]:
+        dace.transformation.dataflow.MapExpansion.apply_to(
+            sdfg,
+            options={"inner_schedule": dace.ScheduleType.Default},
+            map_entry=me,
+        )
+
+    # Now add a transient scalar at the output of `tsklt1`.
+ tsklt1_oedge = next(iter(state.out_edges(tsklt1))) + me1_inner = tsklt1_oedge.dst + state.add_edge( + tsklt1, "__out", + t1, None, + dace.Memlet("t1[0]"), + ) + state.add_edge( + t1, None, + me1_inner, tsklt1_oedge.dst_conn, + dace.Memlet("t1[0] -> [__i1, __i2, __i3]"), + ) + state.remove_edge(tsklt1_oedge) + tsklt1_oedge = None + + # Now add a transient scalar in the front of `tsklt2`. + tsklt2_iedge = next(iter(state.in_edges(tsklt2))) + me2_inner = tsklt2_iedge.src + state.add_edge( + me2_inner, tsklt2_iedge.src_conn, + t2, None, + dace.Memlet("t2[0] -> [__i1, __i2, __i3]"), + ) + state.add_edge( + t2, None, + tsklt2, "__in1", + dace.Memlet("t2[0]"), + ) + state.remove_edge(tsklt2_iedge) + tsklt2_iedge = None + sdfg.validate() + + # By Specifying `apply_once` we only perform one fusion, which will eliminate `T`. + # This is not efficient, we do this to make sure that the update of the Memlets + # has worked. + apply_fusion(sdfg, apply_once=True) + + for edge in state.edges(): + # There should be no edge, that references `T`. + assert edge.data.data != "T" + + # If an edge is connected to `t2` or `t1` then its data should refer to it. + # no other Memlet shall refer to them. + for t in [t1, t2]: + if edge.src is t or edge.dst is t: + assert edge.data.data == t.data + else: + assert edge.data.data != t.data + + args_ref = { + 'A': np.array(np.random.rand(10, 11, 12), dtype=np.float64, copy=True), + 'B': np.array(np.random.rand(10, 11, 12), dtype=np.float64, copy=True), + } + args_res = copy.deepcopy(args_ref) + + ref(**args_ref) + sdfg(**args_res) + for arg in args_ref.keys(): + arg_ref = args_ref[arg] + arg_res = args_res[arg] + assert np.allclose(arg_ref, arg_res) + + +def _make_possible_cycle_if_fuesed_sdfg() -> Tuple[dace.SDFG, nodes.MapExit, nodes.AccessNode, nodes.MapEntry]: + """Generate an SDFG that if two maps would be fused a cycle would be created. + + Essentially tests if the MapFusion detects this special case. 
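+
+    If `map1` (producing `T` and `U`) were fused with `map3` (consuming `T`
+    and `V`), the fused map would produce `U` while also consuming `V`, which
+    `map2` computes from `U`; the resulting path through `U`, `map2` and `V`
+    back into the fused map would form a cycle.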
+ """ + sdfg = dace.SDFG("possible_cycle_if_fuesed_sdfg") + state = sdfg.add_state(is_start_block=True) + + names = ["A", "B", "T", "U", "V"] + for name in names: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=True, + ) + sdfg.arrays["A"].transient = False + sdfg.arrays["B"].transient = False + + A, B, T, U, V = (state.add_access(name) for name in names) + + _, _, first_map_exit = state.add_mapped_tasklet( + "map1", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("A[__i]")}, + code="__out1 = __in + 10\n__out2 = __in - 10", + outputs={ + "__out1": dace.Memlet("T[__i]"), + "__out2": dace.Memlet("U[__i]"), + }, + input_nodes={A}, + output_nodes={T, U}, + external_edges=True, + ) + + state.add_mapped_tasklet( + "map2", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("U[__i]")}, + code="__out = math.sin(__in)", + outputs={"__out": dace.Memlet("V[__i]")}, + input_nodes={U}, + output_nodes={V}, + external_edges=True, + ) + + _, second_map_entry, _ = state.add_mapped_tasklet( + "map3", + map_ranges={"__i": "0:10"}, + inputs={ + "__in1": dace.Memlet("T[__i]"), + "__in2": dace.Memlet("V[__i]"), + }, + code="__out = __in1 + __in2", + outputs={"__out": dace.Memlet("B[__i]")}, + input_nodes={T, V}, + output_nodes={B}, + external_edges=True, + ) + sdfg.validate() + + return sdfg, first_map_exit, T, second_map_entry + + +def test_possible_cycle_if_fuesed_sdfg(): + sdfg, first_map_exit, array, second_map_entry = _make_possible_cycle_if_fuesed_sdfg() + + would_transformation_apply = MapFusion.can_be_applied_to( + sdfg, + first_map_exit=first_map_exit, + array=array, + second_map_entry=second_map_entry, + ) + assert not would_transformation_apply + + if __name__ == '__main__': + test_fusion_intrinsic_memlet_direction() + test_fusion_dynamic_producer() + test_fusion_different_global_accesses() + test_fusion_multiple_consumers() + test_fusion_intermediate_different_access() + test_fusion_intermediate_different_access_mod_shape() + test_fusion_non_strict_dataflow_implicit_dependency() + test_fusion_strict_dataflow_pointwise() + test_fusion_strict_dataflow_not_pointwise() + test_fusion_dataflow_intermediate() + test_fusion_dataflow_intermediate_2() + test_fusion_dataflow_intermediate_downstream() + test_indirect_accesses() + test_fusion_shared() + test_fusion_with_transient() + test_fusion_rename() test_fusion_simple() test_multiple_fusions() test_fusion_chain() - test_fusion_with_transient() test_fusion_with_transient_scalar() test_fusion_with_inverted_indices() test_fusion_with_empty_memlet() test_fusion_with_nested_sdfg_0() + test_interstate_fusion() test_fusion_with_nested_sdfg_1() + test_fuse_indirect_accesses() + test_offset_correction_range_read() + test_offset_correction_scalar_read() + test_offset_correction_empty() + test_different_offsets() + test_inner_map_dependency() + test_inner_map_dependency_resolved() + test_possible_cycle_if_fuesed_sdfg() diff --git a/tests/transformations/warp_tiling_test.py b/tests/transformations/warp_tiling_test.py index 7c75d08878..4ca424a1c1 100644 --- a/tests/transformations/warp_tiling_test.py +++ b/tests/transformations/warp_tiling_test.py @@ -38,19 +38,27 @@ def test_warp_softmax(vector_length=1): sdfg = softmax_fwd.to_sdfg(simplify=True) # Apply transformations - sdfg.apply_transformations_repeated(ReduceExpansion) + sdfg.apply_transformations_repeated(ReduceExpansion, validate_all=True) MultiExpansion.apply_to(sdfg, sdfg.node(0).nodes()) SubgraphFusion.apply_to(sdfg, sdfg.node(0).nodes()) 
sdfg.expand_library_nodes() sdfg.simplify() - sdfg.apply_transformations_repeated([TrivialMapElimination, MapFusion]) - sdfg.apply_transformations(GPUTransformSDFG) + sdfg.apply_transformations_repeated([TrivialMapElimination, MapFusion], validate_all=True) + sdfg.apply_transformations(GPUTransformSDFG, validate_all=True) assert sdfg.apply_transformations(WarpTiling) == 1 - sdfg.apply_transformations_repeated([HoistState, InlineSDFG, StateFusion]) - sdfg.apply_transformations_repeated([TrivialMapElimination, MapFusion]) + sdfg.apply_transformations_repeated([HoistState, InlineSDFG, StateFusion], validate_all=True) + sdfg.apply_transformations_repeated([TrivialMapElimination, MapFusion], validate_all=True) if vector_length != 1: sdfg.apply_transformations_repeated( - Vectorization, dict(vector_len=vector_length, preamble=False, postamble=False, strided_map=False)) + Vectorization, + dict( + vector_len=vector_length, + preamble=False, + postamble=False, + strided_map=False + ), + validate_all=True + ) sdfg.specialize(dict(dn1=2, dn2=16, dn3=128, dr=128)) # Check validity