From aa433fe58b40a0a939a4d7a37e64c6e41f877850 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 22 Aug 2024 15:50:26 +0200 Subject: [PATCH 001/115] Started with a first version of the map fusion stuff. --- dace/transformation/dataflow/__init__.py | 2 +- dace/transformation/dataflow/map_fusion.py | 539 +------------ .../dataflow/map_fusion_helper.py | 740 ++++++++++++++++++ .../dataflow/map_fusion_parallel.py | 117 +++ .../dataflow/map_fusion_serial.py | 471 +++++++++++ 5 files changed, 1334 insertions(+), 535 deletions(-) create mode 100644 dace/transformation/dataflow/map_fusion_helper.py create mode 100644 dace/transformation/dataflow/map_fusion_parallel.py create mode 100644 dace/transformation/dataflow/map_fusion_serial.py diff --git a/dace/transformation/dataflow/__init__.py b/dace/transformation/dataflow/__init__.py index db4c928481..dbd3838d9f 100644 --- a/dace/transformation/dataflow/__init__.py +++ b/dace/transformation/dataflow/__init__.py @@ -8,7 +8,7 @@ from .map_for_loop import MapToForLoop, MapToForLoopRegion from .map_interchange import MapInterchange from .map_dim_shuffle import MapDimShuffle -from .map_fusion import MapFusion +from .map_fusion import MapFusion, ParallelMapFusion, SerialMapFusion from .map_fission import MapFission from .map_unroll import MapUnroll from .trivial_map_elimination import TrivialMapElimination diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index a6762d45c4..3735d3e7dc 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -1,537 +1,8 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -""" This module contains classes that implement the map fusion transformation. -""" +"""Make all map fusion transformations available.""" -from copy import deepcopy as dcpy -from dace.sdfg.sdfg import SDFG -from dace.sdfg.state import SDFGState -from dace import data, dtypes, symbolic, subsets -from dace.sdfg import nodes -from dace.memlet import Memlet -from dace.sdfg import replace -from dace.sdfg import utils as sdutil -from dace.transformation import transformation -from typing import List, Union -import networkx as nx +from .map_fusion_serial import SerialMapFusion +from .map_fusion_parallel import ParallelMapFusion - -class MapFusion(transformation.SingleStateTransformation): - """ Implements the MapFusion transformation. - It wil check for all patterns MapExit -> AccessNode -> MapEntry, and - based on the following rules, fuse them and remove the transient in - between. There are several possibilities of what it does to this - transient in between. - - Essentially, if there is some other place in the - sdfg where it is required, or if it is not a transient, then it will - not be removed. In such a case, it will be linked to the MapExit node - of the new fused map. - - Rules for fusing maps: - 0. The map range of the second map should be a permutation of the - first map range. - 1. Each of the access nodes that are adjacent to the first map exit - should have an edge to the second map entry. If it doesn't, then the - second map entry should not be reachable from this access node. - 2. Any node that has a wcr from the first map exit should not be - adjacent to the second map entry. - 3. Access pattern for the access nodes in the second map should be - the same permutation of the map parameters as the map ranges of the - two maps. Alternatively, this access node should not be adjacent to - the first map entry. - """ - first_map_exit = transformation.PatternNode(nodes.ExitNode) - array = transformation.PatternNode(nodes.AccessNode) - second_map_entry = transformation.PatternNode(nodes.EntryNode) - - @staticmethod - def annotates_memlets(): - return False - - @classmethod - def expressions(cls): - return [sdutil.node_path_graph(cls.first_map_exit, cls.array, cls.second_map_entry)] - - @staticmethod - def find_permutation(first_map: nodes.Map, second_map: nodes.Map) -> Union[List[int], None]: - """ Find permutation between two map ranges. - - :param first_map: First map. - :param second_map: Second map. - :return: None if no such permutation exists, otherwise a list of - indices L such that L[x]'th parameter of second map has the same range as x'th - parameter of the first map. - """ - result = [] - - if len(first_map.range) != len(second_map.range): - return None - - # Match map ranges with reduce ranges - for i, tmap_rng in enumerate(first_map.range): - found = False - for j, rng in enumerate(second_map.range): - if tmap_rng == rng and j not in result: - result.append(j) - found = True - break - if not found: - break - - # Ensure all map ranges matched - if len(result) != len(first_map.range): - return None - - return result - - def can_be_applied(self, graph, expr_index, sdfg: SDFG, permissive=False): - first_map_exit = self.first_map_exit - first_map_entry = graph.entry_node(first_map_exit) - second_map_entry = self.second_map_entry - second_map_exit = graph.exit_node(second_map_entry) - - for _in_e in graph.in_edges(first_map_exit): - if _in_e.data.wcr is not None: - for _out_e in graph.out_edges(second_map_entry): - if _out_e.data.data == _in_e.data.data: - # wcr is on a node that is used in the second map, quit - return False - # Check whether there is a pattern map -> access -> map. - intermediate_nodes = set() - intermediate_data = set() - for _, _, dst, _, _ in graph.out_edges(first_map_exit): - if isinstance(dst, nodes.AccessNode): - intermediate_nodes.add(dst) - intermediate_data.add(dst.data) - - # If array is used anywhere else in this state. - num_occurrences = len([n for n in sdfg.data_nodes() if n.data == dst.data]) - if num_occurrences > 1: - return False - else: - return False - # Check map ranges - perm = self.find_permutation(first_map_entry.map, second_map_entry.map) - if perm is None: - return False - - # Check if any intermediate transient is also going to another location - second_inodes = set(e.src for e in graph.in_edges(second_map_entry) if isinstance(e.src, nodes.AccessNode)) - transients_to_remove = intermediate_nodes & second_inodes - # if any(e.dst != second_map_entry for n in transients_to_remove - # for e in graph.out_edges(n)): - if any(graph.out_degree(n) > 1 for n in transients_to_remove): - return False - - # Create a dict that maps parameters of the first map to those of the - # second map. - params_dict = {} - for _index, _param in enumerate(second_map_entry.map.params): - params_dict[_param] = first_map_entry.map.params[perm[_index]] - - out_memlets = [e.data for e in graph.in_edges(first_map_exit)] - - # Check that input set of second map is provided by the output set - # of the first map, or other unrelated maps - for second_edge in graph.out_edges(second_map_entry): - # NOTE: We ignore edges that do not carry data (e.g., connecting a tasklet with no inputs to the MapEntry) - if second_edge.data.is_empty(): - continue - # Memlets that do not come from one of the intermediate arrays - if second_edge.data.data not in intermediate_data: - # however, if intermediate_data eventually leads to - # second_memlet.data, need to fail. - for _n in intermediate_nodes: - source_node = _n - destination_node = graph.memlet_path(second_edge)[0].src - # NOTE: Assumes graph has networkx version - if destination_node in nx.descendants(graph._nx, source_node): - return False - continue - - provided = False - - # Compute second subset with respect to first subset's symbols - sbs_permuted = dcpy(second_edge.data.subset) - if sbs_permuted: - # Create intermediate dicts to avoid conflicts, such as {i:j, j:i} - symbolic.safe_replace(params_dict, lambda m: sbs_permuted.replace(m)) - - for first_memlet in out_memlets: - if first_memlet.data != second_edge.data.data: - continue - - # If there is a covered subset, it is provided - if first_memlet.subset.covers(sbs_permuted): - provided = True - break - - # If none of the output memlets of the first map provide the info, - # fail. - if provided is False: - return False - - # Checking for stencil pattern and common input/output data - # (after fusing the maps) - first_map_inputnodes = { - e.src: e.src.data - for e in graph.in_edges(first_map_entry) if isinstance(e.src, nodes.AccessNode) - } - input_views = set() - viewed_inputnodes = dict() - for n in first_map_inputnodes.keys(): - if isinstance(n.desc(sdfg), data.View): - input_views.add(n) - for v in input_views: - del first_map_inputnodes[v] - e = sdutil.get_view_edge(graph, v) - if e: - while not isinstance(e.src, nodes.AccessNode): - e = graph.memlet_path(e)[0] - first_map_inputnodes[e.src] = e.src.data - viewed_inputnodes[e.src.data] = v - second_map_outputnodes = { - e.dst: e.dst.data - for e in graph.out_edges(second_map_exit) if isinstance(e.dst, nodes.AccessNode) - } - output_views = set() - viewed_outputnodes = dict() - for n in second_map_outputnodes: - if isinstance(n.desc(sdfg), data.View): - output_views.add(n) - for v in output_views: - del second_map_outputnodes[v] - e = sdutil.get_view_edge(graph, v) - if e: - while not isinstance(e.dst, nodes.AccessNode): - e = graph.memlet_path(e)[-1] - second_map_outputnodes[e.dst] = e.dst.data - viewed_outputnodes[e.dst.data] = v - common_data = set(first_map_inputnodes.values()).intersection(set(second_map_outputnodes.values())) - if common_data: - input_data = [viewed_inputnodes[d].data if d in viewed_inputnodes.keys() else d for d in common_data] - input_accesses = [ - graph.memlet_path(e)[-1].data.src_subset for e in graph.out_edges(first_map_entry) - if e.data.data in input_data - ] - if len(input_accesses) > 1: - for i, a in enumerate(input_accesses[:-1]): - for b in input_accesses[i + 1:]: - if isinstance(a, subsets.Indices): - c = subsets.Range.from_indices(a) - c.offset(b, negative=True) - else: - c = a.offset_new(b, negative=True) - for r in c: - if r != (0, 0, 1): - return False - - output_data = [viewed_outputnodes[d].data if d in viewed_outputnodes.keys() else d for d in common_data] - output_accesses = [ - graph.memlet_path(e)[0].data.dst_subset for e in graph.in_edges(second_map_exit) - if e.data.data in output_data - ] - - # Compute output accesses with respect to first map's symbols - oacc_permuted = [dcpy(a) for a in output_accesses] - for a in oacc_permuted: - # Create intermediate dicts to avoid conflicts, such as {i:j, j:i} - symbolic.safe_replace(params_dict, lambda m: a.replace(m)) - - a = input_accesses[0] - for b in oacc_permuted: - if isinstance(a, subsets.Indices): - c = subsets.Range.from_indices(a) - c.offset(b, negative=True) - else: - c = a.offset_new(b, negative=True) - for r in c: - if r != (0, 0, 1): - return False - - # Success - return True - - def apply(self, graph: SDFGState, sdfg: SDFG): - """ - This method applies the mapfusion transformation. - Other than the removal of the second map entry node (SME), and the first - map exit (FME) node, it has the following side effects: - - 1. Any transient adjacent to both FME and SME with degree = 2 will be removed. - The tasklets that use/produce it shall be connected directly with a - scalar/new transient (if the dataflow is more than a single scalar) - - 2. If this transient is adjacent to FME and SME and has other - uses, it will be adjacent to the new map exit post fusion. - Tasklet-> Tasklet edges will ALSO be added as mentioned above. - - 3. If an access node is adjacent to FME but not SME, it will be - adjacent to new map exit post fusion. - - 4. If an access node is adjacent to SME but not FME, it will be - adjacent to the new map entry node post fusion. - - """ - first_exit = self.first_map_exit - first_entry = graph.entry_node(first_exit) - second_entry = self.second_map_entry - second_exit = graph.exit_node(second_entry) - - intermediate_nodes = set() - for _, _, dst, _, _ in graph.out_edges(first_exit): - intermediate_nodes.add(dst) - assert isinstance(dst, nodes.AccessNode) - - # Check if an access node refers to non transient memory, or transient - # is used at another location (cannot erase) - do_not_erase = set() - for node in intermediate_nodes: - if sdfg.arrays[node.data].transient is False: - do_not_erase.add(node) - else: - for edge in graph.in_edges(node): - if edge.src != first_exit: - do_not_erase.add(node) - break - else: - for edge in graph.out_edges(node): - if edge.dst != second_entry: - do_not_erase.add(node) - break - - # Find permutation between first and second scopes - perm = self.find_permutation(first_entry.map, second_entry.map) - params_dict = {} - for index, param in enumerate(first_entry.map.params): - params_dict[param] = second_entry.map.params[perm[index]] - - # Replaces (in memlets and tasklet) the second scope map - # indices with the permuted first map indices. - # This works in two passes to avoid problems when e.g., exchanging two - # parameters (instead of replacing (j,i) and (i,j) to (j,j) and then - # i,i). - second_scope = graph.scope_subgraph(second_entry) - for firstp, secondp in params_dict.items(): - if firstp != secondp: - replace(second_scope, secondp, '__' + secondp + '_fused') - for firstp, secondp in params_dict.items(): - if firstp != secondp: - replace(second_scope, '__' + secondp + '_fused', firstp) - - # Isolate First exit node - ############################ - edges_to_remove = set() - nodes_to_remove = set() - for edge in graph.in_edges(first_exit): - tree = graph.memlet_tree(edge) - access_node = tree.root().edge.dst - if access_node not in do_not_erase: - out_edges = [e for e in graph.out_edges(access_node) if e.dst == second_entry] - # In this transformation, there can only be one edge to the - # second map - assert len(out_edges) == 1 - - # Get source connector to the second map - connector = out_edges[0].dst_conn[3:] - - new_dsts = [] - # Look at the second map entry out-edges to get the new - # destinations - for e in graph.out_edges(second_entry): - if e.src_conn and e.src_conn[4:] == connector: - new_dsts.append(e) - if not new_dsts: # Access node is not used in the second map - nodes_to_remove.add(access_node) - continue - - # Add a transient scalar/array - self.fuse_nodes(sdfg, graph, edge, new_dsts[0].dst, new_dsts[0].dst_conn, new_dsts[1:]) - - edges_to_remove.add(edge) - - # Remove transient node between the two maps - nodes_to_remove.add(access_node) - else: # The case where intermediate array node cannot be removed - # Node will become an output of the second map exit - out_e = tree.parent.edge - conn = second_exit.next_connector() - graph.add_edge( - second_exit, - 'OUT_' + conn, - out_e.dst, - out_e.dst_conn, - dcpy(out_e.data), - ) - second_exit.add_out_connector('OUT_' + conn) - - graph.add_edge(edge.src, edge.src_conn, second_exit, 'IN_' + conn, dcpy(edge.data)) - second_exit.add_in_connector('IN_' + conn) - - edges_to_remove.add(out_e) - edges_to_remove.add(edge) - - # If the second map needs this node, link the connector - # that generated this to the place where it is needed, with a - # temp transient/scalar for memlet to be generated - for out_e in graph.out_edges(second_entry): - second_memlet_path = graph.memlet_path(out_e) - source_node = second_memlet_path[0].src - if source_node == access_node: - self.fuse_nodes(sdfg, graph, edge, out_e.dst, out_e.dst_conn) - - ### - # First scope exit is isolated and can now be safely removed - for e in edges_to_remove: - graph.remove_edge(e) - graph.remove_nodes_from(nodes_to_remove) - graph.remove_node(first_exit) - - # Isolate second_entry node - ########################### - for edge in graph.in_edges(second_entry): - tree = graph.memlet_tree(edge) - access_node = tree.root().edge.src - if access_node in intermediate_nodes: - # Already handled above, can be safely removed - graph.remove_edge(edge) - continue - - # This is an external input to the second map which will now go - # through the first map. - conn = first_entry.next_connector() - graph.add_edge(edge.src, edge.src_conn, first_entry, 'IN_' + conn, dcpy(edge.data)) - first_entry.add_in_connector('IN_' + conn) - graph.remove_edge(edge) - for out_enode in tree.children: - out_e = out_enode.edge - graph.add_edge( - first_entry, - 'OUT_' + conn, - out_e.dst, - out_e.dst_conn, - dcpy(out_e.data), - ) - graph.remove_edge(out_e) - first_entry.add_out_connector('OUT_' + conn) - - # NOTE: Check the second MapEntry for output edges with empty memlets - for edge in graph.out_edges(second_entry): - if edge.data.is_empty(): - graph.remove_edge(edge) - graph.add_edge(first_entry, edge.src_conn, edge.dst, edge.dst_conn, edge.data) - - ### - # Second node is isolated and can now be safely removed - graph.remove_node(second_entry) - - # Fix scope exit to point to the right map - second_exit.map = first_entry.map - - def fuse_nodes(self, sdfg: SDFG, graph: SDFGState, edge, new_dst, new_dst_conn, other_edges=None): - """ Fuses two nodes via memlets and possibly transient arrays. """ - other_edges = other_edges or [] - memlet_path = graph.memlet_path(edge) - access_node = memlet_path[-1].dst - - local_name = "__s%d_n%d%s_n%d%s" % ( - self.state_id, - graph.node_id(edge.src), - edge.src_conn, - graph.node_id(edge.dst), - edge.dst_conn, - ) - # Add intermediate memory between subgraphs. - # If a scalar, uses direct connection. If an array, adds a transient node. - # NOTE: If any of the src/dst nodes is a nested SDFG, treat it as an array. - is_scalar = edge.data.subset.num_elements() == 1 - accesses = ( - [graph.memlet_path(e1)[0].src for e0 in graph.in_edges(access_node) for e1 in graph.memlet_tree(e0)] + - [graph.memlet_path(e1)[-1].dst for e0 in graph.out_edges(access_node) for e1 in graph.memlet_tree(e0)]) - if any(isinstance(a, nodes.NestedSDFG) for a in accesses): - is_scalar = False - if is_scalar: - local_name, _ = sdfg.add_scalar( - local_name, - dtype=access_node.desc(graph).dtype, - transient=True, - storage=dtypes.StorageType.Register, - find_new_name=True, - ) - edge.data.data = local_name - edge.data.subset = "0" - - # If source of edge leads to multiple destinations, redirect all through an access node. - out_edges = list(graph.out_edges_by_connector(edge.src, edge.src_conn)) - if len(out_edges) > 1: - local_node = graph.add_access(local_name) - src_connector = None - - # Add edge that leads to transient node - graph.add_edge(edge.src, edge.src_conn, local_node, None, dcpy(edge.data)) - - for other_edge in out_edges: - if other_edge is not edge: - graph.remove_edge(other_edge) - mem = Memlet(data=local_name, other_subset=other_edge.data.dst_subset) - graph.add_edge(local_node, src_connector, other_edge.dst, other_edge.dst_conn, mem) - else: - local_node = edge.src - src_connector = edge.src_conn - - # update edge data in case source or destination is a scalar access node - test_data = [node.data for node in (edge.src, edge.dst) if isinstance(node, nodes.AccessNode)] - for new_data in test_data: - if isinstance(sdfg.arrays[new_data], data.Scalar): - edge.data.data = new_data - - # If destination of edge leads to multiple destinations, redirect all through an access node. - if other_edges: - # NOTE: If a new local node was already created, reuse it. - if local_node == edge.src: - local_node_out = graph.add_access(local_name) - connector_out = None - else: - local_node_out = local_node - connector_out = src_connector - graph.add_edge(local_node, src_connector, local_node_out, connector_out, - Memlet.from_array(local_name, sdfg.arrays[local_name])) - graph.add_edge(local_node_out, connector_out, new_dst, new_dst_conn, dcpy(edge.data)) - for e in other_edges: - graph.add_edge(local_node_out, connector_out, e.dst, e.dst_conn, dcpy(edge.data)) - else: - # Add edge that leads to the second node - graph.add_edge(local_node, src_connector, new_dst, new_dst_conn, dcpy(edge.data)) - - else: - local_name, _ = sdfg.add_transient(local_name, - symbolic.overapproximate(edge.data.subset.size()), - dtype=access_node.desc(graph).dtype, - find_new_name=True) - old_edge = dcpy(edge) - local_node = graph.add_access(local_name) - src_connector = None - edge.data.data = local_name - edge.data.subset = ",".join(["0:" + str(s) for s in edge.data.subset.size()]) - # Add edge that leads to transient node - graph.add_edge( - edge.src, - edge.src_conn, - local_node, - None, - dcpy(edge.data), - ) - - # Add edge that leads to the second node - graph.add_edge(local_node, src_connector, new_dst, new_dst_conn, dcpy(edge.data)) - - for e in other_edges: - graph.add_edge(local_node, src_connector, e.dst, e.dst_conn, dcpy(edge.data)) - - # Modify data and memlets on all surrounding edges to match array - for neighbor in graph.all_edges(local_node): - for e in graph.memlet_tree(neighbor): - if e.data.data == local_name: - continue - e.data.data = local_name - e.data.subset.offset(old_edge.data.subset, negative=True) +# Compatibility with previous versions of DaCe and clients. +MapFusion = SerialMapFusion diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py new file mode 100644 index 0000000000..6493a63da0 --- /dev/null +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -0,0 +1,740 @@ +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. + +"""Implements Helper functionaliyies for map fusion""" + +import functools +import itertools +from typing import Any, Optional, Sequence, Union + +import dace +from dace import ( + data as dace_data, + properties as dace_properties, + subsets as dace_subsets, + transformation as dace_transformation, +) +from dace.sdfg import ( + SDFG, + SDFGState, + graph as dace_graph, + nodes as dace_nodes, + validation as dace_validation, +) +from dace.transformation import helpers as dace_helpers +from dace.transformation.dataflow import map_fusion_helper + +@dace_properties.make_properties +class MapFusionHelper(dace_transformation.SingleStateTransformation): + """Contains common part of the fusion for parallel and serial Map fusion. + + The transformation assumes that the SDFG obeys the principals outlined [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). + The main advantage of this structure is, that it is rather easy to determine + if a transient is used anywhere else. This check, performed by + `is_interstate_transient()`. It is further speeded up by cashing some computation, + thus such an object should not be used after interstate optimizations were applied + to the SDFG. + + Args: + only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + only_toplevel_maps: Only consider Maps that are at the top. + """ + + only_toplevel_maps = dace_properties.Property( + dtype=bool, + default=False, + allow_none=False, + desc="Only perform fusing if the Maps are in the top level.", + ) + only_inner_maps = dace_properties.Property( + dtype=bool, + default=False, + allow_none=False, + desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", + ) + shared_transients = dace_properties.DictProperty( + key_type=SDFG, + value_type=set[str], + default=None, + allow_none=True, + desc="Maps SDFGs to the set of array transients that can not be removed. " + "The variable acts as a cache, and is managed by 'is_interstate_transient()'.", + ) + + def __init__( + self, + only_inner_maps: Optional[bool] = None, + only_toplevel_maps: Optional[bool] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if only_toplevel_maps is not None: + self.only_toplevel_maps = bool(only_toplevel_maps) + if only_inner_maps is not None: + self.only_inner_maps = bool(only_inner_maps) + self.shared_transients = {} + + @classmethod + def expressions(cls) -> bool: + raise RuntimeError("The `_MapFusionHelper` is not a transformation on its own.") + + def can_be_fused( + self, + map_entry_1: dace_nodes.MapEntry, + map_entry_2: dace_nodes.MapEntry, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Performs basic checks if the maps can be fused. + + This function only checks constrains that are common between serial and + parallel map fusion process, which includes: + - The scope of the maps. + - The scheduling of the maps. + - The map parameters. + + However, for performance reasons, the function does not check if the node + decomposition exists. + + Args: + map_entry_1: The entry of the first (in serial case the top) map. + map_exit_2: The entry of the second (in serial case the bottom) map. + graph: The SDFGState in which the maps are located. + sdfg: The SDFG itself. + permissive: Currently unused. + """ + if self.only_inner_maps and self.only_toplevel_maps: + raise ValueError("You specified both `only_inner_maps` and `only_toplevel_maps`.") + + # Ensure that both have the same schedule + if map_entry_1.map.schedule != map_entry_2.map.schedule: + return False + + # Fusing is only possible if the two entries are in the same scope. + scope = graph.scope_dict() + if scope[map_entry_1] != scope[map_entry_2]: + return False + elif self.only_inner_maps: + if scope[map_entry_1] is None: + return False + elif self.only_toplevel_maps: + if scope[map_entry_1] is not None: + return False + # TODO(phimuell): Figuring out why this is here. + elif map_fusion_helper.is_nested_sdfg(sdfg): + return False + + # We will now check if there exists a "remapping" that we can use. + if not self.map_parameter_compatible( + map_1=map_entry_1.map, map_2=map_entry_2.map, state=graph, sdfg=sdfg + ): + return False + + return True + + @staticmethod + def relocate_nodes( + from_node: Union[dace_nodes.MapExit, dace_nodes.MapEntry], + to_node: Union[dace_nodes.MapExit, dace_nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, + ) -> None: + """Move the connectors and edges from `from_node` to `to_nodes` node. + + This function will only rewire the edges, it does not remove the nodes + themselves. Furthermore, this function should be called twice per Map, + once for the entry and then for the exit. + While it does not remove the node themselves if guarantees that the + `from_node` has degree zero. + + Args: + from_node: Node from which the edges should be removed. + to_node: Node to which the edges should reconnect. + state: The state in which the operation happens. + sdfg: The SDFG that is modified. + """ + + # Now we relocate empty Memlets, from the `from_node` to the `to_node` + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.out_edges(from_node))): + dace_helpers.redirect_edge(state, empty_edge, new_src=to_node) + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.in_edges(from_node))): + dace_helpers.redirect_edge(state, empty_edge, new_dst=to_node) + + # We now ensure that there is only one empty Memlet from the `to_node` to any other node. + # Although it is allowed, we try to prevent it. + empty_targets: set[dace_nodes.Node] = set() + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.all_edges(to_node))): + if empty_edge.dst in empty_targets: + state.remove_edge(empty_edge) + empty_targets.add(empty_edge.dst) + + # We now determine which edges we have to migrate, for this we are looking at + # the incoming edges, because this allows us also to detect dynamic map ranges. + for edge_to_move in list(state.in_edges(from_node)): + assert isinstance(edge_to_move.dst_conn, str) + + if not edge_to_move.dst_conn.startswith("IN_"): + # Dynamic Map Range + # The connector name simply defines a variable name that is used, + # inside the Map scope to define a variable. We handle it directly. + dmr_symbol = edge_to_move.dst_conn + + # TODO(phimuell): Check if the symbol is really unused in the target scope. + if dmr_symbol in to_node.in_connectors: + raise NotImplementedError( + f"Tried to move the dynamic map range '{dmr_symbol}' from {from_node}'" + f" to '{to_node}', but the symbol is already known there, but the" + " renaming is not implemented." + ) + if not to_node.add_in_connector(dmr_symbol, force=False): + raise RuntimeError( # Might fail because of out connectors. + f"Failed to add the dynamic map range symbol '{dmr_symbol}' to '{to_node}'." + ) + dace_helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) + from_node.remove_in_connector(dmr_symbol) + + # There is no other edge that we have to consider, so we just end here + continue + + # We have a Passthrough connection, i.e. there exists a matching `OUT_`. + old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix + new_conn = to_node.next_connector(old_conn) + + to_node.add_in_connector("IN_" + new_conn) + for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): + dace_helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) + to_node.add_out_connector("OUT_" + new_conn) + for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): + dace_helpers.redirect_edge( + state, e, new_src=to_node, new_src_conn="OUT_" + new_conn + ) + from_node.remove_in_connector("IN_" + old_conn) + from_node.remove_out_connector("OUT_" + old_conn) + + # Check if we succeeded. + if state.out_degree(from_node) != 0: + raise dace_validation.InvalidSDFGError( + f"Failed to relocate the outgoing edges from `{from_node}`, there are still `{state.out_edges(from_node)}`", + sdfg, + sdfg.node_id(state), + ) + if state.in_degree(from_node) != 0: + raise dace_validation.InvalidSDFGError( + f"Failed to relocate the incoming edges from `{from_node}`, there are still `{state.in_edges(from_node)}`", + sdfg, + sdfg.node_id(state), + ) + assert len(from_node.in_connectors) == 0 + assert len(from_node.out_connectors) == 0 + + @staticmethod + def map_parameter_compatible( + map_1: dace_nodes.Map, + map_2: dace_nodes.Map, + state: Union[SDFGState, SDFG], + sdfg: SDFG, + ) -> bool: + """Checks if the parameters of `map_1` are compatible with `map_2`. + + The check follows the following rules: + - The names of the map variables must be the same, i.e. no renaming + is performed. + - The ranges must be the same. + """ + range_1: dace_subsets.Range = map_1.range + params_1: Sequence[str] = map_1.params + range_2: dace_subsets.Range = map_2.range + params_2: Sequence[str] = map_2.params + + # The maps are only fuseable if we have an exact match in the parameter names + # this is because we do not do any renaming. This is in accordance with the + # rules. + if set(params_1) != set(params_2): + return False + + # Maps the name of a parameter to the dimension index + param_dim_map_1: dict[str, int] = {pname: i for i, pname in enumerate(params_1)} + param_dim_map_2: dict[str, int] = {pname: i for i, pname in enumerate(params_2)} + + # To fuse the two maps the ranges must have the same ranges + for pname in params_1: + idx_1 = param_dim_map_1[pname] + idx_2 = param_dim_map_2[pname] + # TODO(phimuell): do we need to call simplify? + if range_1[idx_1] != range_2[idx_2]: + return False + + return True + + def is_interstate_transient( + self, + transient: Union[str, dace_nodes.AccessNode], + sdfg: dace.SDFG, + state: dace.SDFGState, + ) -> bool: + """Tests if `transient` is an interstate transient, an can not be removed. + + Essentially this function checks if a transient might be needed in a + different state in the SDFG, because it transmit information from + one state to the other. + If only the name of the data container is passed the function will + first look for an corresponding access node. + + The set of these "interstate transients" is computed once per SDFG. + The result is then cached internally for later reuse. + + Args: + transient: The transient that should be checked. + sdfg: The SDFG containing the array. + state: If given the state the node is located in. + + Note: + This function build upon the structure of the SDFG that is outlined + in the HackMD document. + """ + + # According to [rule 6](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG) + # the set of such transients is partially given by all source access dace_nodes. + # Because of rule 3 we also include all scalars in this set, as an over + # approximation. Furthermore, because simplify might violate rule 3, + # we also include the sink dace_nodes. + + # See if we have already computed the set + if sdfg in self.shared_transients: + shared_sdfg_transients: set[str] = self.shared_transients[sdfg] + else: + # SDFG is not known so we have to compute the set. + shared_sdfg_transients = set() + for state_to_scan in sdfg.all_states(): + # TODO(phimuell): Use `all_nodes_recursive()` once it is available. + shared_sdfg_transients.update( + [ + node.data + for node in itertools.chain( + state_to_scan.source_nodes(), state_to_scan.sink_nodes() + ) + if isinstance(node, dace_nodes.AccessNode) + and sdfg.arrays[node.data].transient + ] + ) + self.shared_transients[sdfg] = shared_sdfg_transients + + if isinstance(transient, str): + name = transient + matching_access_nodes = [node for node in state.data_nodes() if node.data == name] + # Rule 8: There is only one access node per state for data. + assert len(matching_access_nodes) == 1 + transient = matching_access_nodes[0] + else: + assert isinstance(transient, dace_nodes.AccessNode) + name = transient.data + + desc: dace_data.Data = sdfg.arrays[name] + if not desc.transient: + return True + if isinstance(desc, dace_data.Scalar): + return True # Scalars can not be removed by fusion anyway. + + # Rule 8: If degree larger than one then it is used within the state. + if state.out_degree(transient) > 1: + return True + + # Now we check if it is used in a different state. + return name in shared_sdfg_transients + + def partition_first_outputs( + self, + state: SDFGState, + sdfg: SDFG, + map_exit_1: dace_nodes.MapExit, + map_entry_2: dace_nodes.MapEntry, + ) -> Union[ + tuple[ + set[dace_graph.MultiConnectorEdge[dace.Memlet]], + set[dace_graph.MultiConnectorEdge[dace.Memlet]], + set[dace_graph.MultiConnectorEdge[dace.Memlet]], + ], + None, + ]: + """Partition the output edges of `map_exit_1` for serial map fusion. + + The output edges of the first map are partitioned into three distinct sets, + defined as follows: + + - Pure Output Set `\mathbb{P}`: + These edges exits the first map and does not enter the second map. These + outputs will be simply be moved to the output of the second map. + - Exclusive Intermediate Set `\mathbb{E}`: + Edges in this set leaves the first map exit, enters an access node, from + where a Memlet then leads immediately to the second map. The memory + referenced by this access node is not used anywhere else, thus it can + be removed. + - Shared Intermediate Set `\mathbb{S}`: + These edges are very similar to the one in `\mathbb{E}` except that they + are used somewhere else, thus they can not be removed and have to be + recreated as output of the second map. + + Returns: + If such a decomposition exists the function will return the three sets + mentioned above in the same order. + In case the decomposition does not exist, i.e. the maps can not be fused + the function returns `None`. + + Args: + state: The in which the two maps are located. + sdfg: The full SDFG in whcih we operate. + map_exit_1: The exit node of the first map. + map_entry_2: The entry node of the second map. + """ + # The three outputs set. + pure_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() + exclusive_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() + shared_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() + + # Set of intermediate nodes that we have already processed. + processed_inter_nodes: set[dace_nodes.Node] = set() + + # Now scan all output edges of the first exit and classify them + for out_edge in state.out_edges(map_exit_1): + intermediate_node: dace_nodes.Node = out_edge.dst + + # We already processed the node, this should indicate that we should + # run simplify again, or we should start implementing this case. + if intermediate_node in processed_inter_nodes: + return None + processed_inter_nodes.add(intermediate_node) + + # Now let's look at all nodes that are downstream of the intermediate node. + # This, among other things, will tell us, how we have to handle this node. + downstream_nodes = map_fusion_helper.all_nodes_between( + graph=state, + begin=intermediate_node, + end=map_entry_2, + ) + + # If `downstream_nodes` is `None` this means that `map_entry_2` was never + # reached, thus `intermediate_node` does not enter the second map and + # the node is a pure output node. + if downstream_nodes is None: + pure_outputs.add(out_edge) + continue + + # The following tests are _after_ we have determined if we have a pure + # output node, because this allows us to handle more exotic pure node + # cases, as handling them is essentially rerouting an edge, whereas + # handling intermediate nodes is much more complicated. + + # Empty Memlets are only allowed if they are in `\mathbb{P}`, which + # is also the only place they really make sense (for a map exit). + # Thus if we now found an empty Memlet we reject it. + if out_edge.data.is_empty(): + return None + + # In case the intermediate has more than one entry, all must come from the + # first map, otherwise we can not fuse them. Currently we restrict this + # even further by saying that it has only one incoming Memlet. + if state.in_degree(intermediate_node) != 1: + return None + + # It can happen that multiple edges converges at the `IN_` connector + # of the first map exit, but there is only one edge leaving the exit. + # It is complicate to handle this, so for now we ignore it. + # TODO(phimuell): Handle this case properly. + inner_collector_edges = list( + state.in_edges_by_connector(intermediate_node, "IN_" + out_edge.src_conn[3:]) + ) + if len(inner_collector_edges) > 1: + return None + + # For us an intermediate node must always be an access node, because + # everything else we do not know how to handle. It is important that + # we do not test for non transient data here, because they can be + # handled has shared intermediates. + if not isinstance(intermediate_node, dace_nodes.AccessNode): + return None + intermediate_desc: dace_data.Data = intermediate_node.desc(sdfg) + if isinstance(intermediate_desc, dace_data.View): + return None + + # There are some restrictions we have on intermediate dace_nodes. The first one + # is that we do not allow WCR, this is because they need special handling + # which is currently not implement (the DaCe transformation has this + # restriction as well). The second one is that we can reduce the + # intermediate node and only feed a part into the second map, consider + # the case `b = a + 1; return b + 2`, where we have arrays. In this + # example only a single element must be available to the second map. + # However, this is hard to check so we will make a simplification. + # First, we will not check it at the producer, but at the consumer point. + # There we assume if the consumer does _not consume the whole_ + # intermediate array, then we can decompose the intermediate, by setting + # the map iteration index to zero and recover the shape, see + # implementation in the actual fusion routine. + # This is an assumption that is in most cases correct, but not always. + # However, doing it correctly is extremely complex. + for _, produce_edge in map_fusion_helper.find_upstream_producers(state, out_edge): + if produce_edge.data.wcr is not None: + return None + + if len(downstream_nodes) == 0: + # There is nothing between intermediate node and the entry of the + # second map, thus the edge belongs either in `\mathbb{S}` or + # `\mathbb{E}`. + + # This is a very special situation, i.e. the access node has many + # different connections to the second map entry, this is a special + # case that we do not handle. + # TODO(phimuell): Handle this case. + if state.out_degree(intermediate_node) != 1: + return None + + # Certain nodes need more than one element as input. As explained + # above, in this situation we assume that we can naturally decompose + # them iff the node does not consume that whole intermediate. + # Furthermore, it can not be a dynamic map range or a library node. + intermediate_size = functools.reduce(lambda a, b: a * b, intermediate_desc.shape) + consumers = map_fusion_helper.find_downstream_consumers(state=state, begin=intermediate_node) + for consumer_node, feed_edge in consumers: + # TODO(phimuell): Improve this approximation. + if ( + intermediate_size != 1 + ) and feed_edge.data.num_elements() == intermediate_size: + return None + if consumer_node is map_entry_2: # Dynamic map range. + return None + if isinstance(consumer_node, dace_nodes.LibraryNode): + # TODO(phimuell): Allow some library dace_nodes. + return None + + # Note that "remove" has a special meaning here, regardless of the + # output of the check function, from within the second map we remove + # the intermediate, it has more the meaning of "do we need to + # reconstruct it after the second map again?" + if self.is_interstate_transient(intermediate_node, sdfg, state): + shared_outputs.add(out_edge) + else: + exclusive_outputs.add(out_edge) + continue + + else: + # There is not only a single connection from the intermediate node to + # the second map, but the intermediate has more connections, thus + # the node might belong to the shared output. Of the many different + # possibilities, we only consider a single case: + # - The intermediate has a single connection to the second map, that + # fulfills the restriction outlined above. + # - All other connections have no connection to the second map. + found_second_entry = False + intermediate_size = functools.reduce(lambda a, b: a * b, intermediate_desc.shape) + for edge in state.out_edges(intermediate_node): + if edge.dst is map_entry_2: + if found_second_entry: # The second map was found again. + return None + found_second_entry = True + consumers = map_fusion_helper.find_downstream_consumers(state=state, begin=edge) + for consumer_node, feed_edge in consumers: + if feed_edge.data.num_elements() == intermediate_size: + return None + if consumer_node is map_entry_2: # Dynamic map range + return None + if isinstance(consumer_node, dace_nodes.LibraryNode): + # TODO(phimuell): Allow some library dace_nodes. + return None + else: + # Ensure that there is no path that leads to the second map. + after_intermdiate_node = map_fusion_helper.all_nodes_between( + graph=state, begin=edge.dst, end=map_entry_2 + ) + if after_intermdiate_node is not None: + return None + # If we are here, then we know that the node is a shared output + shared_outputs.add(out_edge) + continue + + assert exclusive_outputs or shared_outputs or pure_outputs + assert len(processed_inter_nodes) == sum( + len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs] + ) + return (pure_outputs, exclusive_outputs, shared_outputs) + + +def is_nested_sdfg( + sdfg: Union[dace.SDFG, dace.SDFGState, dace_nodes.NestedSDFG], +) -> bool: + """Tests if `sdfg` is a NestedSDFG.""" + if isinstance(sdfg, dace.SDFGState): + sdfg = sdfg.parent + if isinstance(sdfg, dace_nodes.NestedSDFG): + return True + elif isinstance(sdfg, dace.SDFG): + if sdfg.parent_nsdfg_node is not None: + return True + return False + else: + raise TypeError(f"Does not know how to handle '{type(sdfg).__name__}'.") + + +def all_nodes_between( + graph: dace.SDFG | dace.SDFGState, + begin: dace_nodes.Node, + end: dace_nodes.Node, + reverse: bool = False, +) -> set[dace_nodes.Node] | None: + """Find all nodes that are reachable from `begin` but bound by `end`. + + Essentially the function starts a DFS at `begin`. If an edge is found that lead + to `end`, this edge is ignored. It will thus found any node that is reachable + from `begin` by a path that does not involve `end`. The returned set will + never contain `end` nor `begin`. In case `end` is never found the function + will return `None`. + + If `reverse` is set to `True` the function will start exploring at `end` and + follows the outgoing edges, i.e. the meaning of `end` and `begin` are swapped. + + Args: + graph: The graph to operate on. + begin: The start of the DFS. + end: The terminator node of the DFS. + reverse: Perform a backward DFS. + + Notes: + - The returned set will also contain the nodes of path that starts at + `begin` and ends at a node that is not `end`. + """ + + def next_nodes(node: dace_nodes.Node) -> Iterable[dace_nodes.Node]: + if reverse: + return (edge.src for edge in graph.in_edges(node)) + return (edge.dst for edge in graph.out_edges(node)) + + if reverse: + begin, end = end, begin + + to_visit: list[dace_nodes.Node] = [begin] + seen: set[dace_nodes.Node] = set() + found_end: bool = False + + while len(to_visit) > 0: + n: dace_nodes.Node = to_visit.pop() + if n == end: + found_end = True + continue + elif n in seen: + continue + seen.add(n) + to_visit.extend(next_nodes(n)) + + if not found_end: + return None + + seen.discard(begin) + return seen + + +def is_parallel( + graph: dace.SDFG | dace.SDFGState, + node1: dace_nodes.Node, + node2: dace_nodes.Node, +) -> bool: + """Tests if `node1` and `node2` are parallel. + + The nodes are parallel if `node2` can not be reached from `node1` and vice versa. + + Args: + graph: The graph to traverse. + node1: The first node to check. + node2: The second node to check. + """ + + # The `all_nodes_between()` function traverse the graph and returns `None` if + # `end` was not found. We have to call it twice, because we do not know + # which node is upstream if they are not parallel. + if all_nodes_between(graph=graph, begin=node1, end=node2) is not None: + return False + elif all_nodes_between(graph=graph, begin=node2, end=node1) is not None: + return False + return True + + +def find_downstream_consumers( + state: dace.SDFGState, + begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], + only_tasklets: bool = False, + reverse: bool = False, +) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: + """Find all downstream connectors of `begin`. + + A consumer, in for this function, is any node that is neither an entry nor + an exit node. The function returns a set of pairs, the first element is the + node that acts as consumer and the second is the edge that leads to it. + By setting `only_tasklets` the nodes the function finds are only Tasklets. + + To find this set the function starts a search at `begin`, however, it is also + possible to pass an edge as `begin`. + If `reverse` is `True` the function essentially finds the producers that are + upstream. + + Args: + state: The state in which to look for the consumers. + begin: The initial node that from which the search starts. + only_tasklets: Return only Tasklets. + reverse: Follow the reverse direction. + """ + if isinstance(begin, dace_graph.MultiConnectorEdge): + to_visit: list[dace_graph.MultiConnectorEdge[dace.Memlet]] = [begin] + elif reverse: + to_visit = list(state.in_edges(begin)) + else: + to_visit = list(state.out_edges(begin)) + seen: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() + found: set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]] = set() + + while len(to_visit) != 0: + curr_edge: dace_graph.MultiConnectorEdge[dace.Memlet] = to_visit.pop() + next_node: dace_nodes.Node = curr_edge.src if reverse else curr_edge.dst + + if curr_edge in seen: + continue + seen.add(curr_edge) + + if isinstance(next_node, (dace_nodes.MapEntry, dace_nodes.MapExit)): + if reverse: + target_conn = curr_edge.src_conn[4:] + new_edges = state.in_edges_by_connector(curr_edge.src, "IN_" + target_conn) + else: + # In forward mode a Map entry could also mean the definition of a + # dynamic map range. + if (not curr_edge.dst_conn.startswith("IN_")) and isinstance( + next_node, dace_nodes.MapEntry + ): + # This edge defines a dynamic map range, which is a consumer + if not only_tasklets: + found.add((next_node, curr_edge)) + continue + target_conn = curr_edge.dst_conn[3:] + new_edges = state.out_edges_by_connector(curr_edge.dst, "OUT_" + target_conn) + to_visit.extend(new_edges) + del new_edges + else: + if only_tasklets and (not isinstance(next_node, dace_nodes.Tasklet)): + continue + found.add((next_node, curr_edge)) + + return found + + +def find_upstream_producers( + state: dace.SDFGState, + begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], + only_tasklets: bool = False, +) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: + """Same as `find_downstream_consumers()` but with `reverse` set to `True`.""" + return find_downstream_consumers( + state=state, + begin=begin, + only_tasklets=only_tasklets, + reverse=True, + ) + + + + diff --git a/dace/transformation/dataflow/map_fusion_parallel.py b/dace/transformation/dataflow/map_fusion_parallel.py new file mode 100644 index 0000000000..179227d23f --- /dev/null +++ b/dace/transformation/dataflow/map_fusion_parallel.py @@ -0,0 +1,117 @@ +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. + +"""Implements the parallel map fusing transformation.""" + +from typing import Any, Optional, Union + +import dace +from dace import properties as dace_properties, transformation as dace_transformation +from dace.sdfg import SDFG, SDFGState, graph as dace_graph, nodes as dace_nodes + +from dace.transformation.dataflow import map_fusion_helper + +@dace_properties.make_properties +class ParallelMapFusion(map_fusion_helper.MapFusionHelper): + """The `ParallelMapFusion` transformation allows to merge two parallel maps together. + + The `SerialMapFusion` transformation is only able to handle maps that are sequential, + however, this transformation is able to fuse _any_ maps that are not sequential + and are in the same scope. + + Args: + only_if_common_ancestor: Only perform fusion if both Maps share at least one + node as direct ancestor. This will increase the locality of the merge. + only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + only_toplevel_maps: Only consider Maps that are at the top. + + Note: + This transformation only matches the entry nodes of the Map, but will also + modify the exit nodes of the Map. + """ + + map_entry1 = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + map_entry2 = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + + only_if_common_ancestor = dace_properties.Property( + dtype=bool, + default=False, + allow_none=False, + desc="Only perform fusing if the Maps share a node as parent.", + ) + + def __init__( + self, + only_if_common_ancestor: Optional[bool] = None, + **kwargs: Any, + ) -> None: + if only_if_common_ancestor is not None: + self.only_if_common_ancestor = only_if_common_ancestor + super().__init__(**kwargs) + + @classmethod + def expressions(cls) -> Any: + # This just matches _any_ two Maps inside a state. + state = dace_graph.OrderedMultiDiConnectorGraph() + state.add_nodes_from([cls.map_entry1, cls.map_entry2]) + return [state] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """The transformation is applicable.""" + map_entry_1: dace_nodes.MapEntry = self.map_entry1 + map_entry_2: dace_nodes.MapEntry = self.map_entry2 + + # Check the structural properties of the maps, this will also ensure that + # the two maps are in the same scope. + if not self.can_be_fused( + map_entry_1=map_entry_1, + map_entry_2=map_entry_2, + graph=graph, + sdfg=sdfg, + permissive=permissive, + ): + return False + + # Since the match expression matches any twp Maps, we have to ensure that + # the maps are parallel. The `can_be_fused()` function already verified + # if they are in the same scope. + if not map_fusion_helper.is_parallel(graph=graph, node1=map_entry_1, node2=map_entry_2): + return False + + # Test if they have they share a node as direct ancestor. + if self.only_if_common_ancestor: + # This assumes that there is only one access node per data container in the state. + ancestors_1: set[dace_nodes.Node] = {e1.src for e1 in graph.in_edges(map_entry_1)} + if not any(e2.src in ancestors_1 for e2 in graph.in_edges(map_entry_2)): + return False + + return True + + def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: + """Performs the Map fusing. + + Essentially, the function relocate all edges from the nodes forming the second + Map to the corresponding nodes of the first Map. Afterwards the nodes of the + second Map are removed. + """ + assert self.map_parameter_compatible(self.map_entry1.map, self.map_entry2.map, graph, sdfg) + + map_entry_1: dace_nodes.MapEntry = self.map_entry1 + map_exit_1: dace_nodes.MapExit = graph.exit_node(map_entry_1) + map_entry_2: dace_nodes.MapEntry = self.map_entry2 + map_exit_2: dace_nodes.MapExit = graph.exit_node(map_entry_2) + + for to_node, from_node in zip((map_entry_1, map_exit_1), (map_entry_2, map_exit_2)): + self.relocate_nodes( + from_node=from_node, + to_node=to_node, + state=graph, + sdfg=sdfg, + ) + # The relocate function does not remove the node, so we must do it. + graph.remove_node(from_node) diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py new file mode 100644 index 0000000000..0e7dbe014b --- /dev/null +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -0,0 +1,471 @@ +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. + +"""Implements the serial map fusing transformation.""" + +import copy +from typing import Any, Union + +import dace +from dace import ( + dtypes as dace_dtypes, + properties as dace_properties, + subsets as dace_subsets, + symbolic as dace_symbolic, + transformation as dace_transformation, +) +from dace.sdfg import SDFG, SDFGState, graph as dace_graph, nodes as dace_nodes + +from dace.transformation.dataflow import map_fusion_helper + + +@dace_properties.make_properties +class SerialMapFusion(map_fusion_helper.MapFusionHelper): + """Specialized replacement for the map fusion transformation that is provided by DaCe. + + As its name is indicating this transformation is only able to handle Maps that + are in sequence. Compared to the native DaCe transformation, this one is able + to handle more complex cases of connection between the maps. In that sense, it + is much more similar to DaCe's `SubgraphFusion` transformation. + + Things that are improved, compared to the native DaCe implementation: + - Nested Maps. + - Temporary arrays and the correct propagation of their Memlets. + - Top Maps that have multiple outputs. + + Conceptually this transformation removes the exit of the first or upper map + and the entry of the lower or second map and then rewrites the connections + appropriately. + + This transformation assumes that an SDFG obeys the structure that is outlined + [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). For that + reason it is not true replacement of the native DaCe transformation. + + Args: + only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + only_toplevel_maps: Only consider Maps that are at the top. + + Notes: + - This transformation modifies more nodes than it matches! + """ + + map_exit1 = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) + access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) + map_entry2 = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + + def __init__( + self, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + @classmethod + def expressions(cls) -> Any: + """Get the match expression. + + The transformation matches the exit node of the top Map that is connected to + an access node that again is connected to the entry node of the second Map. + An important note is, that the transformation operates not just on the + matched nodes, but more or less on anything that has an incoming connection + from the first Map or an outgoing connection to the second Map entry. + """ + return [dace.sdfg.utils.node_path_graph(cls.map_exit1, cls.access_node, cls.map_entry2)] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Tests if the matched Maps can be merged. + + The two Maps are mergeable iff: + - The `can_be_fused()` of the base succeed, which checks some basic constraints. + - The decomposition exists and at least one of the intermediate sets + is not empty. + """ + assert isinstance(self.map_exit1, dace_nodes.MapExit) + assert isinstance(self.map_entry2, dace_nodes.MapEntry) + map_entry_1: dace_nodes.MapEntry = graph.entry_node(self.map_exit1) + map_entry_2: dace_nodes.MapEntry = self.map_entry2 + + # This essentially test the structural properties of the two Maps. + if not self.can_be_fused( + map_entry_1=map_entry_1, map_entry_2=map_entry_2, graph=graph, sdfg=sdfg + ): + return False + + # Two maps can be serially fused if the node decomposition exists and + # at least one of the intermediate output sets is not empty. The state + # of the pure outputs is irrelevant for serial map fusion. + output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + map_exit_1=self.map_exit1, + map_entry_2=self.map_entry2, + ) + if output_partition is None: + return False + _, exclusive_outputs, shared_outputs = output_partition + if not (exclusive_outputs or shared_outputs): + return False + return True + + def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: + """Performs the serial Map fusing. + + The function first computes the map decomposition and then handles the + three sets. The pure outputs are handled by `relocate_nodes()` while + the two intermediate sets are handled by `handle_intermediate_set()`. + + By assumption we do not have to rename anything. + + Args: + graph: The SDFG state we are operating on. + sdfg: The SDFG we are operating on. + """ + # NOTE: `self.map_*` actually stores the ID of the node. + # once we start adding and removing nodes it seems that their ID changes. + # Thus we have to save them here, this is a known behaviour in DaCe. + assert isinstance(graph, dace.SDFGState) + assert isinstance(self.map_exit1, dace_nodes.MapExit) + assert isinstance(self.map_entry2, dace_nodes.MapEntry) + assert self.map_parameter_compatible(self.map_exit1.map, self.map_entry2.map, graph, sdfg) + + map_exit_1: dace_nodes.MapExit = self.map_exit1 + map_entry_2: dace_nodes.MapEntry = self.map_entry2 + map_exit_2: dace_nodes.MapExit = graph.exit_node(self.map_entry2) + map_entry_1: dace_nodes.MapEntry = graph.entry_node(self.map_exit1) + + output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + ) + assert output_partition is not None # Make MyPy happy. + pure_outputs, exclusive_outputs, shared_outputs = output_partition + + if len(exclusive_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=exclusive_outputs, + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + map_exit_2=map_exit_2, + is_exclusive_set=True, + ) + if len(shared_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=shared_outputs, + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + map_exit_2=map_exit_2, + is_exclusive_set=False, + ) + assert pure_outputs == set(graph.out_edges(map_exit_1)) + if len(pure_outputs) != 0: + self.relocate_nodes( + from_node=map_exit_1, + to_node=map_exit_2, + state=graph, + sdfg=sdfg, + ) + + # Above we have handled the input of the second map and moved them + # to the first map, now we must move the output of the first map + # to the second one, as this one is used. + self.relocate_nodes( + from_node=map_entry_2, + to_node=map_entry_1, + state=graph, + sdfg=sdfg, + ) + + for node_to_remove in [map_exit_1, map_entry_2]: + assert graph.degree(node_to_remove) == 0 + graph.remove_node(node_to_remove) + + # Now turn the second output node into the output node of the first Map. + map_exit_2.map = map_entry_1.map + + @staticmethod + def handle_intermediate_set( + intermediate_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]], + state: SDFGState, + sdfg: SDFG, + map_exit_1: dace_nodes.MapExit, + map_entry_2: dace_nodes.MapEntry, + map_exit_2: dace_nodes.MapExit, + is_exclusive_set: bool, + ) -> None: + """This function handles the intermediate sets. + + The function is able to handle both the shared and exclusive intermediate + output set, see `partition_first_outputs()`. The main difference is that + in exclusive mode the intermediate nodes will be fully removed from + the SDFG. While in shared mode the intermediate node will be preserved. + + Args: + intermediate_outputs: The set of outputs, that should be processed. + state: The state in which the map is processed. + sdfg: The SDFG that should be optimized. + map_exit_1: The exit of the first/top map. + map_entry_2: The entry of the second map. + map_exit_2: The exit of the second map. + is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. + + Notes: + Before the transformation the `state` does not have to be valid and + after this function has run the state is (most likely) invalid. + + Todo: + Rewrite using `MemletTree`. + """ + + # Essentially this function removes the AccessNode between the two maps. + # However, we still need some temporary memory that we can use, which is + # just much smaller, i.e. a scalar. But all Memlets inside the second map + # assumes that the intermediate memory has the bigger shape. + # To fix that we will create this replacement dict that will replace all + # occurrences of the iteration variables of the second map with zero. + # Note that this is still not enough as the dimensionality might be different. + memlet_repl: dict[str, int] = {str(param): 0 for param in map_entry_2.map.params} + + # Now we will iterate over all intermediate edges and process them. + # If not stated otherwise the comments assume that we run in exclusive mode. + for out_edge in intermediate_outputs: + # This is the intermediate node that, that we want to get rid of. + # In shared mode we want to recreate it after the second map. + inter_node: dace_nodes.AccessNode = out_edge.dst + inter_name = inter_node.data + inter_desc = inter_node.desc(sdfg) + inter_shape = inter_desc.shape + + # Now we will determine the shape of the new intermediate. This size of + # this temporary is given by the Memlet that goes into the first map exit. + pre_exit_edges = list( + state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) + ) + if len(pre_exit_edges) != 1: + raise NotImplementedError() + pre_exit_edge = pre_exit_edges[0] + new_inter_shape_raw = dace_symbolic.overapproximate(pre_exit_edge.data.subset.size()) + + # Over approximation will leave us with some unneeded size one dimensions. + # That are known to cause some troubles, so we will now remove them. + squeezed_dims: list[int] = [] # These are the dimensions we removed. + new_inter_shape: list[int] = [] # This is the final shape of the new intermediate. + for dim, (proposed_dim_size, full_dim_size) in enumerate( + zip(new_inter_shape_raw, inter_shape) + ): + # Order of checks is important! + if full_dim_size == 1: # Must be kept! + new_inter_shape.append(proposed_dim_size) + elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. + squeezed_dims.append(dim) + else: + new_inter_shape.append(proposed_dim_size) + + # This is the name of the new "intermediate" node that we will create. + # It will only have the shape `new_inter_shape` which is basically its + # output within one Map iteration. + # NOTE: The insertion process might generate a new name. + new_inter_name: str = f"__s{sdfg.node_id(state)}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" + + # Now generate the intermediate data container. + if len(new_inter_shape) == 0: + assert pre_exit_edge.data.subset.num_elements() == 1 + is_scalar = True + new_inter_name, new_inter_desc = sdfg.add_scalar( + new_inter_name, + dtype=inter_desc.dtype, + transient=True, + storage=dace_dtypes.StorageType.Register, + find_new_name=True, + ) + + else: + assert (pre_exit_edge.data.subset.num_elements() > 1) or all( + x == 1 for x in new_inter_shape + ) + is_scalar = False + new_inter_name, new_inter_desc = sdfg.add_transient( + new_inter_name, + shape=new_inter_shape, + dtype=inter_desc.dtype, + find_new_name=True, + ) + new_inter_node: dace_nodes.AccessNode = state.add_access(new_inter_name) + + # New we will reroute the output Memlet, thus it will no longer pass + # through the Map exit but through the newly created intermediate. + # we will delete the previous edge later. + pre_exit_memlet: dace.Memlet = pre_exit_edge.data + new_pre_exit_memlet = copy.deepcopy(pre_exit_memlet) + + # We might operate on a different array, but the check below, ensures + # that we do not change the direction of the Memlet. + assert pre_exit_memlet.data == inter_name + new_pre_exit_memlet.data = new_inter_name + + # Now we have to modify the subset of the Memlet. + # Before the subset of the Memlet was dependent on the Map variables, + # however, this is no longer the case, as we removed them. This change + # has to be reflected in the Memlet. + # NOTE: Assert above ensures that the below is correct. + new_pre_exit_memlet.replace(memlet_repl) + if is_scalar: + new_pre_exit_memlet.subset = "0" + new_pre_exit_memlet.other_subset = None + else: + new_pre_exit_memlet.subset.pop(squeezed_dims) + + # Now we create the new edge between the producer and the new output + # (the new intermediate node). We will remove the old edge further down. + new_pre_exit_edge = state.add_edge( + pre_exit_edge.src, + pre_exit_edge.src_conn, + new_inter_node, + None, + new_pre_exit_memlet, + ) + + # We just have handled the last Memlet, but we must actually handle the + # whole producer side, i.e. the scope of the top Map. + for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children(): + producer_edge = producer_tree.edge + + # Ensure the correctness of the rerouting below. + # TODO(phimuell): Improve the code below to remove the check. + assert producer_edge.data.data == inter_name + + # Will not change the direction, because of test above! + producer_edge.data.data = new_inter_name + producer_edge.data.replace(memlet_repl) + if is_scalar: + producer_edge.data.dst_subset = "0" + elif producer_edge.data.dst_subset is not None: + producer_edge.data.dst_subset.pop(squeezed_dims) + + # Now after we have handled the input of the new intermediate node, + # we must handle its output. For this we have to "inject" the newly + # created intermediate into the second map. We do this by finding + # the input connectors on the map entry, such that we know where we + # have to reroute inside the Map. + # NOTE: Assumes that map (if connected is the direct neighbour). + conn_names: set[str] = set() + for inter_node_out_edge in state.out_edges(inter_node): + if inter_node_out_edge.dst == map_entry_2: + assert inter_node_out_edge.dst_conn.startswith("IN_") + conn_names.add(inter_node_out_edge.dst_conn) + else: + # If we found another target than the second map entry from the + # intermediate node it means that the node _must_ survive, + # i.e. we are not in exclusive mode. + assert not is_exclusive_set + + # Now we will reroute the connections inside the second map, i.e. + # instead of consuming the old intermediate node, they will now + # consume the new intermediate node. + for in_conn_name in conn_names: + out_conn_name = "OUT_" + in_conn_name[3:] + + for inner_edge in state.out_edges_by_connector(map_entry_2, out_conn_name): + assert inner_edge.data.data == inter_name # DIRECTION!! + + # The create the first Memlet to transmit information, within + # the second map, we do this again by copying and modifying + # the original Memlet. + # NOTE: Test above is important to ensure the direction of the + # Memlet and the correctness of the code below. + new_inner_memlet = copy.deepcopy(inner_edge.data) + new_inner_memlet.replace(memlet_repl) + new_inner_memlet.data = new_inter_name # Because of the assert above, this will not change the direction. + + # Now remove the old edge, that started the second map entry. + # Also add the new edge that started at the new intermediate. + state.remove_edge(inner_edge) + new_inner_edge = state.add_edge( + new_inter_node, + None, + inner_edge.dst, + inner_edge.dst_conn, + new_inner_memlet, + ) + + # Now we do subset modification to ensure that nothing failed. + if is_scalar: + new_inner_memlet.src_subset = "0" + elif new_inner_memlet.src_subset is not None: + new_inner_memlet.src_subset.pop(squeezed_dims) + + # Now clean the Memlets of that tree to use the new intermediate node. + for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(): + consumer_edge = consumer_tree.edge + assert consumer_edge.data.data == inter_name + consumer_edge.data.data = new_inter_name + if is_scalar: + consumer_edge.data.src_subset = "0" + elif consumer_edge.data.subset is not None: + consumer_edge.data.subset.pop(squeezed_dims) + + # The edge that leaves the second map entry was already deleted. + # We will now delete the edges that brought the data. + for edge in list(state.in_edges_by_connector(map_entry_2, in_conn_name)): + assert edge.src == inter_node + state.remove_edge(edge) + map_entry_2.remove_in_connector(in_conn_name) + map_entry_2.remove_out_connector(out_conn_name) + + if is_exclusive_set: + # In exclusive mode the old intermediate node is no longer needed. + assert state.degree(inter_node) == 1 + state.remove_edge_and_connectors(out_edge) + state.remove_node(inter_node) + + state.remove_edge(pre_exit_edge) + map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) + map_exit_1.remove_out_connector(out_edge.src_conn) + del sdfg.arrays[inter_name] + + else: + # This is the shared mode, so we have to recreate the intermediate + # node, but this time it is at the exit of the second map. + state.remove_edge(pre_exit_edge) + map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) + + # This is the Memlet that goes from the map internal intermediate + # temporary node to the Map output. This will essentially restore + # or preserve the output for the intermediate node. It is important + # that we use the data that `preExitEdge` was used. + new_exit_memlet = copy.deepcopy(pre_exit_edge.data) + assert new_exit_memlet.data == inter_name + new_exit_memlet.subset = pre_exit_edge.data.dst_subset + new_exit_memlet.other_subset = ( + "0" if is_scalar else dace_subsets.Range.from_array(inter_desc) + ) + + new_pre_exit_conn = map_exit_2.next_connector() + state.add_edge( + new_inter_node, + None, + map_exit_2, + "IN_" + new_pre_exit_conn, + new_exit_memlet, + ) + state.add_edge( + map_exit_2, + "OUT_" + new_pre_exit_conn, + inter_node, + out_edge.dst_conn, + copy.deepcopy(out_edge.data), + ) + map_exit_2.add_in_connector("IN_" + new_pre_exit_conn) + map_exit_2.add_out_connector("OUT_" + new_pre_exit_conn) + + map_exit_1.remove_out_connector(out_edge.src_conn) + state.remove_edge(out_edge) From 71a88a1819d9d2ea265b7273bfd4309857ba1547 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 23 Aug 2024 08:03:43 +0200 Subject: [PATCH 002/115] Made some stylistic modification to teh code. Now using the 3.9 type hints. --- .../dataflow/map_fusion_helper.py | 173 ++++++++---------- .../dataflow/map_fusion_parallel.py | 30 +-- .../dataflow/map_fusion_serial.py | 66 +++---- 3 files changed, 126 insertions(+), 143 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index 6493a63da0..81fd9f054c 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -4,27 +4,16 @@ import functools import itertools -from typing import Any, Optional, Sequence, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Sequence, Tuple, Union import dace -from dace import ( - data as dace_data, - properties as dace_properties, - subsets as dace_subsets, - transformation as dace_transformation, -) -from dace.sdfg import ( - SDFG, - SDFGState, - graph as dace_graph, - nodes as dace_nodes, - validation as dace_validation, -) -from dace.transformation import helpers as dace_helpers +from dace import data, properties, subsets, transformation +from dace.sdfg import SDFG, SDFGState, graph, nodes, validation +from dace.transformation import helpers from dace.transformation.dataflow import map_fusion_helper -@dace_properties.make_properties -class MapFusionHelper(dace_transformation.SingleStateTransformation): +@properties.make_properties +class MapFusionHelper(transformation.SingleStateTransformation): """Contains common part of the fusion for parallel and serial Map fusion. The transformation assumes that the SDFG obeys the principals outlined [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). @@ -39,21 +28,21 @@ class MapFusionHelper(dace_transformation.SingleStateTransformation): only_toplevel_maps: Only consider Maps that are at the top. """ - only_toplevel_maps = dace_properties.Property( + only_toplevel_maps = properties.Property( dtype=bool, default=False, allow_none=False, desc="Only perform fusing if the Maps are in the top level.", ) - only_inner_maps = dace_properties.Property( + only_inner_maps = properties.Property( dtype=bool, default=False, allow_none=False, desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", ) - shared_transients = dace_properties.DictProperty( + shared_transients = properties.DictProperty( key_type=SDFG, - value_type=set[str], + value_type=set, #[str] default=None, allow_none=True, desc="Maps SDFGs to the set of array transients that can not be removed. " @@ -79,8 +68,8 @@ def expressions(cls) -> bool: def can_be_fused( self, - map_entry_1: dace_nodes.MapEntry, - map_entry_2: dace_nodes.MapEntry, + map_entry_1: nodes.MapEntry, + map_entry_2: nodes.MapEntry, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG, permissive: bool = False, @@ -134,8 +123,8 @@ def can_be_fused( @staticmethod def relocate_nodes( - from_node: Union[dace_nodes.MapExit, dace_nodes.MapEntry], - to_node: Union[dace_nodes.MapExit, dace_nodes.MapEntry], + from_node: Union[nodes.MapExit, nodes.MapEntry], + to_node: Union[nodes.MapExit, nodes.MapEntry], state: SDFGState, sdfg: SDFG, ) -> None: @@ -156,13 +145,13 @@ def relocate_nodes( # Now we relocate empty Memlets, from the `from_node` to the `to_node` for empty_edge in list(filter(lambda e: e.data.is_empty(), state.out_edges(from_node))): - dace_helpers.redirect_edge(state, empty_edge, new_src=to_node) + helpers.redirect_edge(state, empty_edge, new_src=to_node) for empty_edge in list(filter(lambda e: e.data.is_empty(), state.in_edges(from_node))): - dace_helpers.redirect_edge(state, empty_edge, new_dst=to_node) + helpers.redirect_edge(state, empty_edge, new_dst=to_node) # We now ensure that there is only one empty Memlet from the `to_node` to any other node. # Although it is allowed, we try to prevent it. - empty_targets: set[dace_nodes.Node] = set() + empty_targets: Set[nodes.Node] = set() for empty_edge in list(filter(lambda e: e.data.is_empty(), state.all_edges(to_node))): if empty_edge.dst in empty_targets: state.remove_edge(empty_edge) @@ -190,7 +179,7 @@ def relocate_nodes( raise RuntimeError( # Might fail because of out connectors. f"Failed to add the dynamic map range symbol '{dmr_symbol}' to '{to_node}'." ) - dace_helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) + helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) from_node.remove_in_connector(dmr_symbol) # There is no other edge that we have to consider, so we just end here @@ -202,10 +191,10 @@ def relocate_nodes( to_node.add_in_connector("IN_" + new_conn) for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): - dace_helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) + helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) to_node.add_out_connector("OUT_" + new_conn) for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): - dace_helpers.redirect_edge( + helpers.redirect_edge( state, e, new_src=to_node, new_src_conn="OUT_" + new_conn ) from_node.remove_in_connector("IN_" + old_conn) @@ -213,13 +202,13 @@ def relocate_nodes( # Check if we succeeded. if state.out_degree(from_node) != 0: - raise dace_validation.InvalidSDFGError( + raise validation.InvalidSDFGError( f"Failed to relocate the outgoing edges from `{from_node}`, there are still `{state.out_edges(from_node)}`", sdfg, sdfg.node_id(state), ) if state.in_degree(from_node) != 0: - raise dace_validation.InvalidSDFGError( + raise validation.InvalidSDFGError( f"Failed to relocate the incoming edges from `{from_node}`, there are still `{state.in_edges(from_node)}`", sdfg, sdfg.node_id(state), @@ -229,8 +218,8 @@ def relocate_nodes( @staticmethod def map_parameter_compatible( - map_1: dace_nodes.Map, - map_2: dace_nodes.Map, + map_1: nodes.Map, + map_2: nodes.Map, state: Union[SDFGState, SDFG], sdfg: SDFG, ) -> bool: @@ -241,9 +230,9 @@ def map_parameter_compatible( is performed. - The ranges must be the same. """ - range_1: dace_subsets.Range = map_1.range + range_1: subsets.Range = map_1.range params_1: Sequence[str] = map_1.params - range_2: dace_subsets.Range = map_2.range + range_2: subsets.Range = map_2.range params_2: Sequence[str] = map_2.params # The maps are only fuseable if we have an exact match in the parameter names @@ -253,8 +242,8 @@ def map_parameter_compatible( return False # Maps the name of a parameter to the dimension index - param_dim_map_1: dict[str, int] = {pname: i for i, pname in enumerate(params_1)} - param_dim_map_2: dict[str, int] = {pname: i for i, pname in enumerate(params_2)} + param_dim_map_1: Dict[str, int] = {pname: i for i, pname in enumerate(params_1)} + param_dim_map_2: Dict[str, int] = {pname: i for i, pname in enumerate(params_2)} # To fuse the two maps the ranges must have the same ranges for pname in params_1: @@ -268,7 +257,7 @@ def map_parameter_compatible( def is_interstate_transient( self, - transient: Union[str, dace_nodes.AccessNode], + transient: Union[str, nodes.AccessNode], sdfg: dace.SDFG, state: dace.SDFGState, ) -> bool: @@ -294,14 +283,14 @@ def is_interstate_transient( """ # According to [rule 6](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG) - # the set of such transients is partially given by all source access dace_nodes. + # the set of such transients is partially given by all source access nodes. # Because of rule 3 we also include all scalars in this set, as an over # approximation. Furthermore, because simplify might violate rule 3, - # we also include the sink dace_nodes. + # we also include the sink nodes. # See if we have already computed the set if sdfg in self.shared_transients: - shared_sdfg_transients: set[str] = self.shared_transients[sdfg] + shared_sdfg_transients: Set[str] = self.shared_transients[sdfg] else: # SDFG is not known so we have to compute the set. shared_sdfg_transients = set() @@ -313,7 +302,7 @@ def is_interstate_transient( for node in itertools.chain( state_to_scan.source_nodes(), state_to_scan.sink_nodes() ) - if isinstance(node, dace_nodes.AccessNode) + if isinstance(node, nodes.AccessNode) and sdfg.arrays[node.data].transient ] ) @@ -326,13 +315,13 @@ def is_interstate_transient( assert len(matching_access_nodes) == 1 transient = matching_access_nodes[0] else: - assert isinstance(transient, dace_nodes.AccessNode) + assert isinstance(transient, nodes.AccessNode) name = transient.data - desc: dace_data.Data = sdfg.arrays[name] + desc: data.Data = sdfg.arrays[name] if not desc.transient: return True - if isinstance(desc, dace_data.Scalar): + if isinstance(desc, data.Scalar): return True # Scalars can not be removed by fusion anyway. # Rule 8: If degree larger than one then it is used within the state. @@ -346,13 +335,13 @@ def partition_first_outputs( self, state: SDFGState, sdfg: SDFG, - map_exit_1: dace_nodes.MapExit, - map_entry_2: dace_nodes.MapEntry, + map_exit_1: nodes.MapExit, + map_entry_2: nodes.MapEntry, ) -> Union[ - tuple[ - set[dace_graph.MultiConnectorEdge[dace.Memlet]], - set[dace_graph.MultiConnectorEdge[dace.Memlet]], - set[dace_graph.MultiConnectorEdge[dace.Memlet]], + Tuple[ + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], ], None, ]: @@ -387,16 +376,16 @@ def partition_first_outputs( map_entry_2: The entry node of the second map. """ # The three outputs set. - pure_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() - exclusive_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() - shared_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() + pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + exclusive_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + shared_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() # Set of intermediate nodes that we have already processed. - processed_inter_nodes: set[dace_nodes.Node] = set() + processed_inter_nodes: Set[nodes.Node] = set() # Now scan all output edges of the first exit and classify them for out_edge in state.out_edges(map_exit_1): - intermediate_node: dace_nodes.Node = out_edge.dst + intermediate_node: nodes.Node = out_edge.dst # We already processed the node, this should indicate that we should # run simplify again, or we should start implementing this case. @@ -450,13 +439,13 @@ def partition_first_outputs( # everything else we do not know how to handle. It is important that # we do not test for non transient data here, because they can be # handled has shared intermediates. - if not isinstance(intermediate_node, dace_nodes.AccessNode): + if not isinstance(intermediate_node, nodes.AccessNode): return None - intermediate_desc: dace_data.Data = intermediate_node.desc(sdfg) - if isinstance(intermediate_desc, dace_data.View): + intermediate_desc: data.Data = intermediate_node.desc(sdfg) + if isinstance(intermediate_desc, data.View): return None - # There are some restrictions we have on intermediate dace_nodes. The first one + # There are some restrictions we have on intermediate nodes. The first one # is that we do not allow WCR, this is because they need special handling # which is currently not implement (the DaCe transformation has this # restriction as well). The second one is that we can reduce the @@ -501,8 +490,8 @@ def partition_first_outputs( return None if consumer_node is map_entry_2: # Dynamic map range. return None - if isinstance(consumer_node, dace_nodes.LibraryNode): - # TODO(phimuell): Allow some library dace_nodes. + if isinstance(consumer_node, nodes.LibraryNode): + # TODO(phimuell): Allow some library nodes. return None # Note that "remove" has a special meaning here, regardless of the @@ -536,8 +525,8 @@ def partition_first_outputs( return None if consumer_node is map_entry_2: # Dynamic map range return None - if isinstance(consumer_node, dace_nodes.LibraryNode): - # TODO(phimuell): Allow some library dace_nodes. + if isinstance(consumer_node, nodes.LibraryNode): + # TODO(phimuell): Allow some library nodes. return None else: # Ensure that there is no path that leads to the second map. @@ -558,12 +547,12 @@ def partition_first_outputs( def is_nested_sdfg( - sdfg: Union[dace.SDFG, dace.SDFGState, dace_nodes.NestedSDFG], + sdfg: Union[dace.SDFG, dace.SDFGState, nodes.NestedSDFG], ) -> bool: """Tests if `sdfg` is a NestedSDFG.""" if isinstance(sdfg, dace.SDFGState): sdfg = sdfg.parent - if isinstance(sdfg, dace_nodes.NestedSDFG): + if isinstance(sdfg, nodes.NestedSDFG): return True elif isinstance(sdfg, dace.SDFG): if sdfg.parent_nsdfg_node is not None: @@ -574,11 +563,11 @@ def is_nested_sdfg( def all_nodes_between( - graph: dace.SDFG | dace.SDFGState, - begin: dace_nodes.Node, - end: dace_nodes.Node, + graph: Union[dace.SDFG, dace.SDFGState], + begin: nodes.Node, + end: nodes.Node, reverse: bool = False, -) -> set[dace_nodes.Node] | None: +) -> Union[Set[nodes.Node], None]: """Find all nodes that are reachable from `begin` but bound by `end`. Essentially the function starts a DFS at `begin`. If an edge is found that lead @@ -601,7 +590,7 @@ def all_nodes_between( `begin` and ends at a node that is not `end`. """ - def next_nodes(node: dace_nodes.Node) -> Iterable[dace_nodes.Node]: + def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: if reverse: return (edge.src for edge in graph.in_edges(node)) return (edge.dst for edge in graph.out_edges(node)) @@ -609,12 +598,12 @@ def next_nodes(node: dace_nodes.Node) -> Iterable[dace_nodes.Node]: if reverse: begin, end = end, begin - to_visit: list[dace_nodes.Node] = [begin] - seen: set[dace_nodes.Node] = set() + to_visit: List[nodes.Node] = [begin] + seen: Set[nodes.Node] = set() found_end: bool = False while len(to_visit) > 0: - n: dace_nodes.Node = to_visit.pop() + n: nodes.Node = to_visit.pop() if n == end: found_end = True continue @@ -631,9 +620,9 @@ def next_nodes(node: dace_nodes.Node) -> Iterable[dace_nodes.Node]: def is_parallel( - graph: dace.SDFG | dace.SDFGState, - node1: dace_nodes.Node, - node2: dace_nodes.Node, + graph: Union[dace.SDFG, dace.SDFGState], + node1: nodes.Node, + node2: nodes.Node, ) -> bool: """Tests if `node1` and `node2` are parallel. @@ -657,10 +646,10 @@ def is_parallel( def find_downstream_consumers( state: dace.SDFGState, - begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], + begin: Union[nodes.Node, graph.MultiConnectorEdge[dace.Memlet]], only_tasklets: bool = False, reverse: bool = False, -) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: +) -> Set[Tuple[nodes.Node, graph.MultiConnectorEdge[dace.Memlet]]]: """Find all downstream connectors of `begin`. A consumer, in for this function, is any node that is neither an entry nor @@ -679,24 +668,24 @@ def find_downstream_consumers( only_tasklets: Return only Tasklets. reverse: Follow the reverse direction. """ - if isinstance(begin, dace_graph.MultiConnectorEdge): - to_visit: list[dace_graph.MultiConnectorEdge[dace.Memlet]] = [begin] + if isinstance(begin, graph.MultiConnectorEdge): + to_visit: List[graph.MultiConnectorEdge[dace.Memlet]] = [begin] elif reverse: to_visit = list(state.in_edges(begin)) else: to_visit = list(state.out_edges(begin)) - seen: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() - found: set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]] = set() + seen: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + found: Set[Tuple[nodes.Node, graph.MultiConnectorEdge[dace.Memlet]]] = set() while len(to_visit) != 0: - curr_edge: dace_graph.MultiConnectorEdge[dace.Memlet] = to_visit.pop() - next_node: dace_nodes.Node = curr_edge.src if reverse else curr_edge.dst + curr_edge: graph.MultiConnectorEdge[dace.Memlet] = to_visit.pop() + next_node: nodes.Node = curr_edge.src if reverse else curr_edge.dst if curr_edge in seen: continue seen.add(curr_edge) - if isinstance(next_node, (dace_nodes.MapEntry, dace_nodes.MapExit)): + if isinstance(next_node, (nodes.MapEntry, nodes.MapExit)): if reverse: target_conn = curr_edge.src_conn[4:] new_edges = state.in_edges_by_connector(curr_edge.src, "IN_" + target_conn) @@ -704,7 +693,7 @@ def find_downstream_consumers( # In forward mode a Map entry could also mean the definition of a # dynamic map range. if (not curr_edge.dst_conn.startswith("IN_")) and isinstance( - next_node, dace_nodes.MapEntry + next_node, nodes.MapEntry ): # This edge defines a dynamic map range, which is a consumer if not only_tasklets: @@ -715,7 +704,7 @@ def find_downstream_consumers( to_visit.extend(new_edges) del new_edges else: - if only_tasklets and (not isinstance(next_node, dace_nodes.Tasklet)): + if only_tasklets and (not isinstance(next_node, nodes.Tasklet)): continue found.add((next_node, curr_edge)) @@ -724,9 +713,9 @@ def find_downstream_consumers( def find_upstream_producers( state: dace.SDFGState, - begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], + begin: Union[nodes.Node, graph.MultiConnectorEdge[dace.Memlet]], only_tasklets: bool = False, -) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: +) -> Set[Tuple[nodes.Node, graph.MultiConnectorEdge[dace.Memlet]]]: """Same as `find_downstream_consumers()` but with `reverse` set to `True`.""" return find_downstream_consumers( state=state, diff --git a/dace/transformation/dataflow/map_fusion_parallel.py b/dace/transformation/dataflow/map_fusion_parallel.py index 179227d23f..0c032cc5f2 100644 --- a/dace/transformation/dataflow/map_fusion_parallel.py +++ b/dace/transformation/dataflow/map_fusion_parallel.py @@ -2,15 +2,15 @@ """Implements the parallel map fusing transformation.""" -from typing import Any, Optional, Union +from typing import Any, Optional, Set, Union import dace -from dace import properties as dace_properties, transformation as dace_transformation -from dace.sdfg import SDFG, SDFGState, graph as dace_graph, nodes as dace_nodes +from dace import properties, transformation +from dace.sdfg import SDFG, SDFGState, graph, nodes from dace.transformation.dataflow import map_fusion_helper -@dace_properties.make_properties +@properties.make_properties class ParallelMapFusion(map_fusion_helper.MapFusionHelper): """The `ParallelMapFusion` transformation allows to merge two parallel maps together. @@ -29,10 +29,10 @@ class ParallelMapFusion(map_fusion_helper.MapFusionHelper): modify the exit nodes of the Map. """ - map_entry1 = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) - map_entry2 = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + map_entry1 = transformation.transformation.PatternNode(nodes.MapEntry) + map_entry2 = transformation.transformation.PatternNode(nodes.MapEntry) - only_if_common_ancestor = dace_properties.Property( + only_if_common_ancestor = properties.Property( dtype=bool, default=False, allow_none=False, @@ -51,7 +51,7 @@ def __init__( @classmethod def expressions(cls) -> Any: # This just matches _any_ two Maps inside a state. - state = dace_graph.OrderedMultiDiConnectorGraph() + state = graph.OrderedMultiDiConnectorGraph() state.add_nodes_from([cls.map_entry1, cls.map_entry2]) return [state] @@ -63,8 +63,8 @@ def can_be_applied( permissive: bool = False, ) -> bool: """The transformation is applicable.""" - map_entry_1: dace_nodes.MapEntry = self.map_entry1 - map_entry_2: dace_nodes.MapEntry = self.map_entry2 + map_entry_1: nodes.MapEntry = self.map_entry1 + map_entry_2: nodes.MapEntry = self.map_entry2 # Check the structural properties of the maps, this will also ensure that # the two maps are in the same scope. @@ -86,7 +86,7 @@ def can_be_applied( # Test if they have they share a node as direct ancestor. if self.only_if_common_ancestor: # This assumes that there is only one access node per data container in the state. - ancestors_1: set[dace_nodes.Node] = {e1.src for e1 in graph.in_edges(map_entry_1)} + ancestors_1: Set[nodes.Node] = {e1.src for e1 in graph.in_edges(map_entry_1)} if not any(e2.src in ancestors_1 for e2 in graph.in_edges(map_entry_2)): return False @@ -101,10 +101,10 @@ def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: """ assert self.map_parameter_compatible(self.map_entry1.map, self.map_entry2.map, graph, sdfg) - map_entry_1: dace_nodes.MapEntry = self.map_entry1 - map_exit_1: dace_nodes.MapExit = graph.exit_node(map_entry_1) - map_entry_2: dace_nodes.MapEntry = self.map_entry2 - map_exit_2: dace_nodes.MapExit = graph.exit_node(map_entry_2) + map_entry_1: nodes.MapEntry = self.map_entry1 + map_exit_1: nodes.MapExit = graph.exit_node(map_entry_1) + map_entry_2: nodes.MapEntry = self.map_entry2 + map_exit_2: nodes.MapExit = graph.exit_node(map_entry_2) for to_node, from_node in zip((map_entry_1, map_exit_1), (map_entry_2, map_exit_2)): self.relocate_nodes( diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py index 0e7dbe014b..3649d1a335 100644 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -3,22 +3,16 @@ """Implements the serial map fusing transformation.""" import copy -from typing import Any, Union +from typing import Any, Dict, List, Set, Union import dace -from dace import ( - dtypes as dace_dtypes, - properties as dace_properties, - subsets as dace_subsets, - symbolic as dace_symbolic, - transformation as dace_transformation, -) -from dace.sdfg import SDFG, SDFGState, graph as dace_graph, nodes as dace_nodes +from dace import dtypes, properties, subsets, symbolic, transformation +from dace.sdfg import SDFG, SDFGState, graph, nodes from dace.transformation.dataflow import map_fusion_helper -@dace_properties.make_properties +@properties.make_properties class SerialMapFusion(map_fusion_helper.MapFusionHelper): """Specialized replacement for the map fusion transformation that is provided by DaCe. @@ -48,9 +42,9 @@ class SerialMapFusion(map_fusion_helper.MapFusionHelper): - This transformation modifies more nodes than it matches! """ - map_exit1 = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) - access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) - map_entry2 = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + map_exit1 = transformation.transformation.PatternNode(nodes.MapExit) + access_node = transformation.transformation.PatternNode(nodes.AccessNode) + map_entry2 = transformation.transformation.PatternNode(nodes.MapEntry) def __init__( self, @@ -84,10 +78,10 @@ def can_be_applied( - The decomposition exists and at least one of the intermediate sets is not empty. """ - assert isinstance(self.map_exit1, dace_nodes.MapExit) - assert isinstance(self.map_entry2, dace_nodes.MapEntry) - map_entry_1: dace_nodes.MapEntry = graph.entry_node(self.map_exit1) - map_entry_2: dace_nodes.MapEntry = self.map_entry2 + assert isinstance(self.map_exit1, nodes.MapExit) + assert isinstance(self.map_entry2, nodes.MapEntry) + map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit1) + map_entry_2: nodes.MapEntry = self.map_entry2 # This essentially test the structural properties of the two Maps. if not self.can_be_fused( @@ -128,14 +122,14 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non # once we start adding and removing nodes it seems that their ID changes. # Thus we have to save them here, this is a known behaviour in DaCe. assert isinstance(graph, dace.SDFGState) - assert isinstance(self.map_exit1, dace_nodes.MapExit) - assert isinstance(self.map_entry2, dace_nodes.MapEntry) + assert isinstance(self.map_exit1, nodes.MapExit) + assert isinstance(self.map_entry2, nodes.MapEntry) assert self.map_parameter_compatible(self.map_exit1.map, self.map_entry2.map, graph, sdfg) - map_exit_1: dace_nodes.MapExit = self.map_exit1 - map_entry_2: dace_nodes.MapEntry = self.map_entry2 - map_exit_2: dace_nodes.MapExit = graph.exit_node(self.map_entry2) - map_entry_1: dace_nodes.MapEntry = graph.entry_node(self.map_exit1) + map_exit_1: nodes.MapExit = self.map_exit1 + map_entry_2: nodes.MapEntry = self.map_entry2 + map_exit_2: nodes.MapExit = graph.exit_node(self.map_entry2) + map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit1) output_partition = self.partition_first_outputs( state=graph, @@ -194,12 +188,12 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non @staticmethod def handle_intermediate_set( - intermediate_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]], + intermediate_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]], state: SDFGState, sdfg: SDFG, - map_exit_1: dace_nodes.MapExit, - map_entry_2: dace_nodes.MapEntry, - map_exit_2: dace_nodes.MapExit, + map_exit_1: nodes.MapExit, + map_entry_2: nodes.MapEntry, + map_exit_2: nodes.MapExit, is_exclusive_set: bool, ) -> None: """This function handles the intermediate sets. @@ -233,14 +227,14 @@ def handle_intermediate_set( # To fix that we will create this replacement dict that will replace all # occurrences of the iteration variables of the second map with zero. # Note that this is still not enough as the dimensionality might be different. - memlet_repl: dict[str, int] = {str(param): 0 for param in map_entry_2.map.params} + memlet_repl: Dict[str, int] = {str(param): 0 for param in map_entry_2.map.params} # Now we will iterate over all intermediate edges and process them. # If not stated otherwise the comments assume that we run in exclusive mode. for out_edge in intermediate_outputs: # This is the intermediate node that, that we want to get rid of. # In shared mode we want to recreate it after the second map. - inter_node: dace_nodes.AccessNode = out_edge.dst + inter_node: nodes.AccessNode = out_edge.dst inter_name = inter_node.data inter_desc = inter_node.desc(sdfg) inter_shape = inter_desc.shape @@ -253,12 +247,12 @@ def handle_intermediate_set( if len(pre_exit_edges) != 1: raise NotImplementedError() pre_exit_edge = pre_exit_edges[0] - new_inter_shape_raw = dace_symbolic.overapproximate(pre_exit_edge.data.subset.size()) + new_inter_shape_raw = symbolic.overapproximate(pre_exit_edge.data.subset.size()) # Over approximation will leave us with some unneeded size one dimensions. # That are known to cause some troubles, so we will now remove them. - squeezed_dims: list[int] = [] # These are the dimensions we removed. - new_inter_shape: list[int] = [] # This is the final shape of the new intermediate. + squeezed_dims: List[int] = [] # These are the dimensions we removed. + new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. for dim, (proposed_dim_size, full_dim_size) in enumerate( zip(new_inter_shape_raw, inter_shape) ): @@ -284,7 +278,7 @@ def handle_intermediate_set( new_inter_name, dtype=inter_desc.dtype, transient=True, - storage=dace_dtypes.StorageType.Register, + storage=dtypes.StorageType.Register, find_new_name=True, ) @@ -299,7 +293,7 @@ def handle_intermediate_set( dtype=inter_desc.dtype, find_new_name=True, ) - new_inter_node: dace_nodes.AccessNode = state.add_access(new_inter_name) + new_inter_node: nodes.AccessNode = state.add_access(new_inter_name) # New we will reroute the output Memlet, thus it will no longer pass # through the Map exit but through the newly created intermediate. @@ -357,7 +351,7 @@ def handle_intermediate_set( # the input connectors on the map entry, such that we know where we # have to reroute inside the Map. # NOTE: Assumes that map (if connected is the direct neighbour). - conn_names: set[str] = set() + conn_names: Set[str] = set() for inter_node_out_edge in state.out_edges(inter_node): if inter_node_out_edge.dst == map_entry_2: assert inter_node_out_edge.dst_conn.startswith("IN_") @@ -446,7 +440,7 @@ def handle_intermediate_set( assert new_exit_memlet.data == inter_name new_exit_memlet.subset = pre_exit_edge.data.dst_subset new_exit_memlet.other_subset = ( - "0" if is_scalar else dace_subsets.Range.from_array(inter_desc) + "0" if is_scalar else subsets.Range.from_array(inter_desc) ) new_pre_exit_conn = map_exit_2.next_connector() From bc87ddb79354175dcbb08e796b5caeb916016592 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 23 Aug 2024 13:04:26 +0200 Subject: [PATCH 003/115] Added a function for estimating if something is pointwhise. But it is too restrictive. --- dace/transformation/dataflow/__init__.py | 2 +- dace/transformation/dataflow/map_fusion.py | 1 + .../dataflow/map_fusion_helper.py | 207 ++++++++++++++---- 3 files changed, 169 insertions(+), 41 deletions(-) diff --git a/dace/transformation/dataflow/__init__.py b/dace/transformation/dataflow/__init__.py index dbd3838d9f..9316949d70 100644 --- a/dace/transformation/dataflow/__init__.py +++ b/dace/transformation/dataflow/__init__.py @@ -8,7 +8,7 @@ from .map_for_loop import MapToForLoop, MapToForLoopRegion from .map_interchange import MapInterchange from .map_dim_shuffle import MapDimShuffle -from .map_fusion import MapFusion, ParallelMapFusion, SerialMapFusion +from .map_fusion import MapFusion, ParallelMapFusion, SerialMapFusion, MapFusionOriginal from .map_fission import MapFission from .map_unroll import MapUnroll from .trivial_map_elimination import TrivialMapElimination diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 3735d3e7dc..c0b458665e 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -3,6 +3,7 @@ from .map_fusion_serial import SerialMapFusion from .map_fusion_parallel import ParallelMapFusion +from .map_fusion_original import MapFusionOriginal # Compatibility with previous versions of DaCe and clients. MapFusion = SerialMapFusion diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index 81fd9f054c..a6724f5010 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -4,7 +4,8 @@ import functools import itertools -from typing import Any, Dict, Iterable, List, Optional, Set, Sequence, Tuple, Union +import re +from typing import Any, Dict, Iterable, List, Literal, Optional, Set, Sequence, Tuple, Union, overload import dace from dace import data, properties, subsets, transformation @@ -380,6 +381,11 @@ def partition_first_outputs( exclusive_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() shared_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + # These are the iteration parameters of the two maps. + # They are not yet modified, that they match each other. + map_params_1: Sequence[str] = map_exit_1.map.params + map_params_2: Sequence[str] = map_entry_2.map.params + # Set of intermediate nodes that we have already processed. processed_inter_nodes: Set[nodes.Node] = set() @@ -390,6 +396,7 @@ def partition_first_outputs( # We already processed the node, this should indicate that we should # run simplify again, or we should start implementing this case. if intermediate_node in processed_inter_nodes: + print(f"399") return None processed_inter_nodes.add(intermediate_node) @@ -413,55 +420,69 @@ def partition_first_outputs( # cases, as handling them is essentially rerouting an edge, whereas # handling intermediate nodes is much more complicated. + # For us an intermediate node must always be an access node, because + # everything else we do not know how to handle. It is important that + # we do not test for non transient data here, because they can be + # handled has shared intermediates. + if not isinstance(intermediate_node, nodes.AccessNode): + print(f"428") + return None + intermediate_desc: data.Data = intermediate_node.desc(sdfg) + if isinstance(intermediate_desc, data.View): + print(f"432") + return None + # Empty Memlets are only allowed if they are in `\mathbb{P}`, which # is also the only place they really make sense (for a map exit). # Thus if we now found an empty Memlet we reject it. if out_edge.data.is_empty(): + print(f"out_endge empty.") return None - # In case the intermediate has more than one entry, all must come from the - # first map, otherwise we can not fuse them. Currently we restrict this - # even further by saying that it has only one incoming Memlet. + # The intermediate now can only have a single source. It might be possible + # to extend this to many inputs as long as they come from the top map. + # NOTE: The output degree is checked implicitly further down, the + # general rule is, that multiple outputs are only allowed if only + # one enters the second Map, the other output must go to different + # consumers, in which case the node is a shared intermediate. if state.in_degree(intermediate_node) != 1: + print(f"449") return None # It can happen that multiple edges converges at the `IN_` connector # of the first map exit, but there is only one edge leaving the exit. # It is complicate to handle this, so for now we ignore it. # TODO(phimuell): Handle this case properly. + # The main reason why we forbid this is because it becomes a bit tricky + # to figuring out the size of the intermediate. inner_collector_edges = list( state.in_edges_by_connector(intermediate_node, "IN_" + out_edge.src_conn[3:]) ) if len(inner_collector_edges) > 1: + print(f"469") return None - # For us an intermediate node must always be an access node, because - # everything else we do not know how to handle. It is important that - # we do not test for non transient data here, because they can be - # handled has shared intermediates. - if not isinstance(intermediate_node, nodes.AccessNode): - return None - intermediate_desc: data.Data = intermediate_node.desc(sdfg) - if isinstance(intermediate_desc, data.View): - return None + # An important assumption we made for fusion is that the data is "point + # wise interchangeable/compatible", for a more involved definition see + # `is_pointwise_subset()`. We will now check this for the "producer side" + # (the consumer side is handled later). There is an important point here, + # in case the new intermediate is only a scalar, then this is completely + # safe. Due to the fact how a Map is defined in SDFG. If the new + # intermediate is not a scalar, such as `A[i, j, :]` in `Map[i=..., j=...]` + # then it is a bit of a gamble and to be fully sure we would need to look + # at the consumer subset, however, these should be edge cases. + # TODO(phimuell): Use the `param_association` to evaluate which dimensions + # are actually used and store this here, below use this to check if the + # same dimensions are accessed by the consumer. + for inner_collector_edge in inner_collector_edges: + if not is_pointwise_subset(inner_collector_edge.data.dst_subset, map_params_1): + print(f"479") + return None - # There are some restrictions we have on intermediate nodes. The first one - # is that we do not allow WCR, this is because they need special handling - # which is currently not implement (the DaCe transformation has this - # restriction as well). The second one is that we can reduce the - # intermediate node and only feed a part into the second map, consider - # the case `b = a + 1; return b + 2`, where we have arrays. In this - # example only a single element must be available to the second map. - # However, this is hard to check so we will make a simplification. - # First, we will not check it at the producer, but at the consumer point. - # There we assume if the consumer does _not consume the whole_ - # intermediate array, then we can decompose the intermediate, by setting - # the map iteration index to zero and recover the shape, see - # implementation in the actual fusion routine. - # This is an assumption that is in most cases correct, but not always. - # However, doing it correctly is extremely complex. + # Another restriction we impose is that we do not allow WCR. for _, produce_edge in map_fusion_helper.find_upstream_producers(state, out_edge): if produce_edge.data.wcr is not None: + print(f"485") return None if len(downstream_nodes) == 0: @@ -469,29 +490,30 @@ def partition_first_outputs( # second map, thus the edge belongs either in `\mathbb{S}` or # `\mathbb{E}`. - # This is a very special situation, i.e. the access node has many - # different connections to the second map entry, this is a special - # case that we do not handle. + # If the intermediate access node as more than one outgoing edge + # it means (because of `downstream_nodes`) that it has multiple + # connections to the second map. We do not allow this. # TODO(phimuell): Handle this case. if state.out_degree(intermediate_node) != 1: + print(f"489") return None - # Certain nodes need more than one element as input. As explained - # above, in this situation we assume that we can naturally decompose - # them iff the node does not consume that whole intermediate. - # Furthermore, it can not be a dynamic map range or a library node. - intermediate_size = functools.reduce(lambda a, b: a * b, intermediate_desc.shape) + # We now look at the consumers, as above we assume that the consumption. + # is point wise, however, we allow multiple consumer. As written + # above is safe if the new intermediate is a scalar, in case of an + # array it is pretty safe (see todo above). + # Furthermore, we disallow certain type of consumer. consumers = map_fusion_helper.find_downstream_consumers(state=state, begin=intermediate_node) for consumer_node, feed_edge in consumers: - # TODO(phimuell): Improve this approximation. - if ( - intermediate_size != 1 - ) and feed_edge.data.num_elements() == intermediate_size: + if not is_pointwise_subset(feed_edge.data.src_subset, map_params_2): + print(f"399// {feed_edge.data.src_subset} | {map_params_2}") return None if consumer_node is map_entry_2: # Dynamic map range. + print(f"399_") return None if isinstance(consumer_node, nodes.LibraryNode): # TODO(phimuell): Allow some library nodes. + print(f"399__") return None # Note that "remove" has a special meaning here, regardless of the @@ -546,6 +568,111 @@ def partition_first_outputs( return (pure_outputs, exclusive_outputs, shared_outputs) +@overload +def is_pointwise_subset( + subset: subsets.Range, + map_params: List[str], + param_association: Literal[False], +) -> bool: + ... + + +@overload +def is_pointwise_subset( + subset: subsets.Range, + map_params: List[str], + param_association: Literal[True], +) -> Optional[List[int]]: + ... + + +def is_pointwise_subset( + subset: subsets.Range, + map_params: List[str], + param_association: bool = False, +) -> bool: + """Tests if `subset` is "point wise" with respect to map parameters `map_params`. + + Essentially a subset is point wise, with respect to map parameters, if it access + the data in a `A[i, j]` manner. An example for a not point wise access would be + `A[i + 1, j]`. However, there are some special cases: + - All map parameters must be used, For example the expression `A[i, :]`, inside + the map `Map[i=0:N, j=0:M]` is not point wise, because `j` is not used. + - On the other hand if `A` is a 3D array then expressions such as `A[i, :, j]` + or `A[i, 3, j]` would be point wise. Although they are not a scalar. + - Furthermore, all parameters must appear exactly once, i.e. accesses such as + `A[i, i]`, even inside `Map[i=0:N]` is not point wise. + + It is important to realize that point wise is a very powerful property, since + it essentially releases us from the check of the order of the parameter. + However, there are some cases were it might fail. + + If the `param_association` argument is set to `True` the function will return the + parameter association, This is a list of integer, that indicates which parameter + was found in which dimension of the subset. + If the subset is point wise the function will return `None`. + + Args: + subset: The subset to inspect. + map_params: The list of parameters to inspect. + param_association: Return the parameter association. + """ + map_patterns = [re.compile(f"\\b{str(map_param)}\\b") for map_param in map_params] + subset_sizes = subset.size_exact() + unused_params = set(map_params) + parameter_to_dim_map: Dict[str, int] = dict() + + # Now go through each dimension of the subset and inspect them. + for dim in range(subset.dims()): + if(subset_sizes[dim] == 1): + # Only a single element is consumed, thus we must test if the access + # is done through a yet unused map parameter only. + ss_idx = str(subset[dim][0]) + for map_param, map_pattern in zip(map_params, map_patterns): + if(ss_idx == map_param): + # The map parameter is used alone without any additions. + if(map_param not in unused_params): + # The map parameter was already used, so we have something + # like `A[i, i]`. Thus it is not point wise! + return None if param_association else False + + # The parameter is used alone, so this is point wise. + unused_params.discard(map_param) + parameter_to_dim_map[map_param] = dim + break + + elif(map_pattern.match(ss_idx)): + # The parameter matches partially, e.g. `A[i + 1]`, and is not point wise + return None if param_association else False + + # If we here then `ss_idx` did not depend in any way on the map parameters. + # This is the case if it is a literal or an other symbol, but we know that + # it is constant (because of how symbols work). If it is really point wise + # depends on if all symbols are consumed. + + elif(subset_sizes[dim] == 0): + # This is a strange case that we ignore but it does not violate point wise. + pass + + else: + # There are multiple elements that are consumed. An example would be + # expressions such as `A[i, :, j]` again for a 2D Map. For now we allow + # them, but it is a bit dangerous to do this because it only works if + # the other map also processed that that with that expression. + # This is a fair assumption. + for ss_element in map(str, subset[dim]): + if any(map_pattern.match(ss) for ss in ss_element): + return None if param_association else False + + # Not all parameters were used, so it is not point wise + if(len(unused_params) != 0): + return None if param_association else False + + if(param_association): + return [parameter_to_dim_map[map_param] for map_param in map_params] + return True + + def is_nested_sdfg( sdfg: Union[dace.SDFG, dace.SDFGState, nodes.NestedSDFG], ) -> bool: From 497a2d6569258f7ca82255ab8cb035b727a6ca42 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 23 Aug 2024 14:22:55 +0200 Subject: [PATCH 004/115] Now there is an error in the actuall rewiering stuff. --- .../dataflow/map_fusion_helper.py | 276 ++++++++---------- 1 file changed, 123 insertions(+), 153 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index a6724f5010..6d5ff355dd 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -8,10 +8,9 @@ from typing import Any, Dict, Iterable, List, Literal, Optional, Set, Sequence, Tuple, Union, overload import dace -from dace import data, properties, subsets, transformation +from dace import data, properties, subsets, transformation, symbolic from dace.sdfg import SDFG, SDFGState, graph, nodes, validation from dace.transformation import helpers -from dace.transformation.dataflow import map_fusion_helper @properties.make_properties class MapFusionHelper(transformation.SingleStateTransformation): @@ -111,7 +110,7 @@ def can_be_fused( if scope[map_entry_1] is not None: return False # TODO(phimuell): Figuring out why this is here. - elif map_fusion_helper.is_nested_sdfg(sdfg): + elif is_nested_sdfg(sdfg): return False # We will now check if there exists a "remapping" that we can use. @@ -256,6 +255,90 @@ def map_parameter_compatible( return True + @staticmethod + def find_perameter_remapping( + first_map: nodes.Map, + second_map: nodes.Map + ) -> Union[Dict[str, str], None]: + """Computes the parameter remapping for the parameters of the _second_ map. + + The returned `dict` maps the parameters of the second map (keys) to parameter + names of the first map (values). Because of how the replace function works + the `dict` describes how to replace the parameters of the second map + with parameters of the first map. + + If the remapping does not exists, the function will return `None`. + Parameters, that already have the correct names, will not be included in the + final mapping. + + Args: + first_map: The first map (these parameters will be replaced). + second_map: The second map, these parameters acts as source. + """ + + # The parameter names + first_params: List[str] = first_map.params + second_params: List[str] = second_map.params + + if len(first_params) != len(second_params): + return None + + # Allows to map a name to the corresponding index + first_pmap: Dict[str, int] = {map_param: i for i, map_param in enumerate(first_params)} + second_pmap: Dict[str, int] = {map_param: i for i, map_param in enumerate(second_map)} + + # The ranges, however, we apply some post processing to them. + simp = lambda e: symbolic.simplify_ext(symbolic.simplify(e)) + first_rngs: List[Tuple[Any, Any, Any]] = [ (simp(rng[0]), simp(rng[1]), simp(rng[2])) for rng in first_map.range] + second_rngs: List[Tuple[Any, Any, Any]] = [ (simp(rng[0]), simp(rng[1]), simp(rng[2])) for rng in second_map.range] + + # These are the parameters of the second map that have not yet associated to + # a parameter of the second map. + unmapped_second_params: Set[str] = set(second_params) + unused_first_params: Set[str] = set(first_params) + + # This is the result (`second_param -> first_param`), note that if no renaming + # is needed then the parameter is not present in the mapping. + final_mapping: Dict[str, str] = {} + + # First we check if a remapping is needed at all, for this we look at the + # parameter that are the same in both maps. + for param in set(first_params).intersection(second_params): + first_rng = first_rngs[first_pmap[param]] + second_rng = second_rngs[second_pmap[param]] + + if first_rng == second_rng: + # They have the same name and the same range, this is already a match. + # Because the names are already the same, we do not have to enter them + # in the `final_mapping` + unmapped_second_params.pop(param) + unused_first_params.pop(param) + + # Check if no remapping is needed. + if len(unmapped_second_params) == 0: + return {} + + # Now we go through all the parameters that we have not mapped yet. + # All of them will result in a remapping. + for unmapped_second_param in unmapped_second_params: + second_rng = second_rngs[second_pmap[unmapped_second_param]] + assert unmapped_second_param not in final_mapping + + # Now look in all not yet used parameters of the first map which to use. + for candidate_param in unused_first_params: + candidate_rng = first_rngs[first_pmap[candidate_param]] + if candidate_rng == second_rng: + final_mapping[unmapped_second_param] = candidate_param + break + else: + # We did not find a candidate, so the remapping does not exist + return None + + unused_first_params.pop(final_mapping[unmapped_second_param]) + + assert len(unused_first_params) == 0 + return final_mapping + def is_interstate_transient( self, transient: Union[str, nodes.AccessNode], @@ -396,13 +479,12 @@ def partition_first_outputs( # We already processed the node, this should indicate that we should # run simplify again, or we should start implementing this case. if intermediate_node in processed_inter_nodes: - print(f"399") return None processed_inter_nodes.add(intermediate_node) # Now let's look at all nodes that are downstream of the intermediate node. # This, among other things, will tell us, how we have to handle this node. - downstream_nodes = map_fusion_helper.all_nodes_between( + downstream_nodes = all_nodes_between( graph=state, begin=intermediate_node, end=map_entry_2, @@ -425,18 +507,15 @@ def partition_first_outputs( # we do not test for non transient data here, because they can be # handled has shared intermediates. if not isinstance(intermediate_node, nodes.AccessNode): - print(f"428") return None intermediate_desc: data.Data = intermediate_node.desc(sdfg) if isinstance(intermediate_desc, data.View): - print(f"432") return None # Empty Memlets are only allowed if they are in `\mathbb{P}`, which # is also the only place they really make sense (for a map exit). # Thus if we now found an empty Memlet we reject it. if out_edge.data.is_empty(): - print(f"out_endge empty.") return None # The intermediate now can only have a single source. It might be possible @@ -446,7 +525,6 @@ def partition_first_outputs( # one enters the second Map, the other output must go to different # consumers, in which case the node is a shared intermediate. if state.in_degree(intermediate_node) != 1: - print(f"449") return None # It can happen that multiple edges converges at the `IN_` connector @@ -459,31 +537,29 @@ def partition_first_outputs( state.in_edges_by_connector(intermediate_node, "IN_" + out_edge.src_conn[3:]) ) if len(inner_collector_edges) > 1: - print(f"469") return None - # An important assumption we made for fusion is that the data is "point - # wise interchangeable/compatible", for a more involved definition see - # `is_pointwise_subset()`. We will now check this for the "producer side" - # (the consumer side is handled later). There is an important point here, - # in case the new intermediate is only a scalar, then this is completely - # safe. Due to the fact how a Map is defined in SDFG. If the new - # intermediate is not a scalar, such as `A[i, j, :]` in `Map[i=..., j=...]` - # then it is a bit of a gamble and to be fully sure we would need to look - # at the consumer subset, however, these should be edge cases. - # TODO(phimuell): Use the `param_association` to evaluate which dimensions - # are actually used and store this here, below use this to check if the - # same dimensions are accessed by the consumer. - for inner_collector_edge in inner_collector_edges: - if not is_pointwise_subset(inner_collector_edge.data.dst_subset, map_params_1): - print(f"479") - return None + # We now look for all producers inside the top map. We need this information + # later to determine if the point wise output of the top map is enough + # to serve as (point wise) input for the second map. + # In addition we will check if there is a WCR edge is present, which we + # can not handle. + producers = find_upstream_producers(state, out_edge) + + # More than one producer is not supported. + if len(producers) != 1: + return None + producer_node, producer_edge = next(iter(producers)) + + # If the producer is a view, then we give up. It is possible to handle, + # but also very complicated. + if isinstance(producer_node, nodes.AccessNode) and isinstance(producer_node.desc(sdfg), data.View): + return None + + # We do not allow that the edge has WCR. + if producer_edge.data.wcr is not None: + return None - # Another restriction we impose is that we do not allow WCR. - for _, produce_edge in map_fusion_helper.find_upstream_producers(state, out_edge): - if produce_edge.data.wcr is not None: - print(f"485") - return None if len(downstream_nodes) == 0: # There is nothing between intermediate node and the entry of the @@ -495,25 +571,23 @@ def partition_first_outputs( # connections to the second map. We do not allow this. # TODO(phimuell): Handle this case. if state.out_degree(intermediate_node) != 1: - print(f"489") return None - # We now look at the consumers, as above we assume that the consumption. - # is point wise, however, we allow multiple consumer. As written - # above is safe if the new intermediate is a scalar, in case of an - # array it is pretty safe (see todo above). - # Furthermore, we disallow certain type of consumer. - consumers = map_fusion_helper.find_downstream_consumers(state=state, begin=intermediate_node) - for consumer_node, feed_edge in consumers: - if not is_pointwise_subset(feed_edge.data.src_subset, map_params_2): - print(f"399// {feed_edge.data.src_subset} | {map_params_2}") + # We can fuse the maps only if the producer, i.e. the top map, + # represented by `producer_edge`, is gives us enough information. + producer_subset: subsets.Range = producer_edge.data.dst_subset + consumers = find_downstream_consumers(state=state, begin=intermediate_node) + + for consumer_node, consumer_edge in consumers: + assert self.map_parameter_compatible(map_exit_1.map, map_entry_2.map, state, sdfg) + # Tests if consumer consume less or equal the amount we generate. + if not producer_subset.covers(consumer_edge.data.src_subset): + print(f"NOT COVER | {producer_subset} | {consumer_edge.data.src_subset}") return None if consumer_node is map_entry_2: # Dynamic map range. - print(f"399_") return None if isinstance(consumer_node, nodes.LibraryNode): # TODO(phimuell): Allow some library nodes. - print(f"399__") return None # Note that "remove" has a special meaning here, regardless of the @@ -535,24 +609,25 @@ def partition_first_outputs( # fulfills the restriction outlined above. # - All other connections have no connection to the second map. found_second_entry = False - intermediate_size = functools.reduce(lambda a, b: a * b, intermediate_desc.shape) for edge in state.out_edges(intermediate_node): if edge.dst is map_entry_2: if found_second_entry: # The second map was found again. return None found_second_entry = True - consumers = map_fusion_helper.find_downstream_consumers(state=state, begin=edge) + consumers = find_downstream_consumers(state=state, begin=edge) for consumer_node, feed_edge in consumers: - if feed_edge.data.num_elements() == intermediate_size: + assert self.map_parameter_compatible(map_exit_1.map, map_entry_2.map, state, sdfg) + # Tests if consumer consume less or equal the amount we generate. + if not producer_subset.covers(consumer_edge.data.src_subset): return None - if consumer_node is map_entry_2: # Dynamic map range + if consumer_node is map_entry_2: # Dynamic map range. return None if isinstance(consumer_node, nodes.LibraryNode): # TODO(phimuell): Allow some library nodes. return None else: # Ensure that there is no path that leads to the second map. - after_intermdiate_node = map_fusion_helper.all_nodes_between( + after_intermdiate_node = all_nodes_between( graph=state, begin=edge.dst, end=map_entry_2 ) if after_intermdiate_node is not None: @@ -568,111 +643,6 @@ def partition_first_outputs( return (pure_outputs, exclusive_outputs, shared_outputs) -@overload -def is_pointwise_subset( - subset: subsets.Range, - map_params: List[str], - param_association: Literal[False], -) -> bool: - ... - - -@overload -def is_pointwise_subset( - subset: subsets.Range, - map_params: List[str], - param_association: Literal[True], -) -> Optional[List[int]]: - ... - - -def is_pointwise_subset( - subset: subsets.Range, - map_params: List[str], - param_association: bool = False, -) -> bool: - """Tests if `subset` is "point wise" with respect to map parameters `map_params`. - - Essentially a subset is point wise, with respect to map parameters, if it access - the data in a `A[i, j]` manner. An example for a not point wise access would be - `A[i + 1, j]`. However, there are some special cases: - - All map parameters must be used, For example the expression `A[i, :]`, inside - the map `Map[i=0:N, j=0:M]` is not point wise, because `j` is not used. - - On the other hand if `A` is a 3D array then expressions such as `A[i, :, j]` - or `A[i, 3, j]` would be point wise. Although they are not a scalar. - - Furthermore, all parameters must appear exactly once, i.e. accesses such as - `A[i, i]`, even inside `Map[i=0:N]` is not point wise. - - It is important to realize that point wise is a very powerful property, since - it essentially releases us from the check of the order of the parameter. - However, there are some cases were it might fail. - - If the `param_association` argument is set to `True` the function will return the - parameter association, This is a list of integer, that indicates which parameter - was found in which dimension of the subset. - If the subset is point wise the function will return `None`. - - Args: - subset: The subset to inspect. - map_params: The list of parameters to inspect. - param_association: Return the parameter association. - """ - map_patterns = [re.compile(f"\\b{str(map_param)}\\b") for map_param in map_params] - subset_sizes = subset.size_exact() - unused_params = set(map_params) - parameter_to_dim_map: Dict[str, int] = dict() - - # Now go through each dimension of the subset and inspect them. - for dim in range(subset.dims()): - if(subset_sizes[dim] == 1): - # Only a single element is consumed, thus we must test if the access - # is done through a yet unused map parameter only. - ss_idx = str(subset[dim][0]) - for map_param, map_pattern in zip(map_params, map_patterns): - if(ss_idx == map_param): - # The map parameter is used alone without any additions. - if(map_param not in unused_params): - # The map parameter was already used, so we have something - # like `A[i, i]`. Thus it is not point wise! - return None if param_association else False - - # The parameter is used alone, so this is point wise. - unused_params.discard(map_param) - parameter_to_dim_map[map_param] = dim - break - - elif(map_pattern.match(ss_idx)): - # The parameter matches partially, e.g. `A[i + 1]`, and is not point wise - return None if param_association else False - - # If we here then `ss_idx` did not depend in any way on the map parameters. - # This is the case if it is a literal or an other symbol, but we know that - # it is constant (because of how symbols work). If it is really point wise - # depends on if all symbols are consumed. - - elif(subset_sizes[dim] == 0): - # This is a strange case that we ignore but it does not violate point wise. - pass - - else: - # There are multiple elements that are consumed. An example would be - # expressions such as `A[i, :, j]` again for a 2D Map. For now we allow - # them, but it is a bit dangerous to do this because it only works if - # the other map also processed that that with that expression. - # This is a fair assumption. - for ss_element in map(str, subset[dim]): - if any(map_pattern.match(ss) for ss in ss_element): - return None if param_association else False - - # Not all parameters were used, so it is not point wise - if(len(unused_params) != 0): - return None if param_association else False - - if(param_association): - return [parameter_to_dim_map[map_param] for map_param in map_params] - return True - - def is_nested_sdfg( sdfg: Union[dace.SDFG, dace.SDFGState, nodes.NestedSDFG], ) -> bool: From 9e36447487262e7aef22b27b2223c412af9b8c98 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 23 Aug 2024 14:32:21 +0200 Subject: [PATCH 005/115] Fixed a bug in the map fusion. When the function was fixing the innteriour of the second map, it did not remove the readiong. --- dace/transformation/dataflow/map_fusion_serial.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py index 3649d1a335..f80ca21240 100644 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -402,6 +402,7 @@ def handle_intermediate_set( consumer_edge = consumer_tree.edge assert consumer_edge.data.data == inter_name consumer_edge.data.data = new_inter_name + consumer_edge.data.replace(memlet_repl) if is_scalar: consumer_edge.data.src_subset = "0" elif consumer_edge.data.subset is not None: From 7a48e0dfaefd86d07dc71f8229e5eed0394856f3 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 23 Aug 2024 15:14:22 +0200 Subject: [PATCH 006/115] Made some formating changes. --- .../dataflow/map_fusion_helper.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index 6d5ff355dd..d52c96dbdd 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -256,7 +256,7 @@ def map_parameter_compatible( return True @staticmethod - def find_perameter_remapping( + def find_parameter_remapping( first_map: nodes.Map, second_map: nodes.Map ) -> Union[Dict[str, str], None]: @@ -289,8 +289,8 @@ def find_perameter_remapping( # The ranges, however, we apply some post processing to them. simp = lambda e: symbolic.simplify_ext(symbolic.simplify(e)) - first_rngs: List[Tuple[Any, Any, Any]] = [ (simp(rng[0]), simp(rng[1]), simp(rng[2])) for rng in first_map.range] - second_rngs: List[Tuple[Any, Any, Any]] = [ (simp(rng[0]), simp(rng[1]), simp(rng[2])) for rng in second_map.range] + first_rngs: List[Tuple[Any, Any, Any]] = [tuple(simp(r) for r in rng) for rng in first_map.range] + second_rngs: List[Tuple[Any, Any, Any]] = [tuple(simp(r) for r in rng) for rng in second_map.range] # These are the parameters of the second map that have not yet associated to # a parameter of the second map. @@ -533,9 +533,7 @@ def partition_first_outputs( # TODO(phimuell): Handle this case properly. # The main reason why we forbid this is because it becomes a bit tricky # to figuring out the size of the intermediate. - inner_collector_edges = list( - state.in_edges_by_connector(intermediate_node, "IN_" + out_edge.src_conn[3:]) - ) + inner_collector_edges = list(state.in_edges_by_connector(intermediate_node, "IN_" + out_edge.src_conn[3:])) if len(inner_collector_edges) > 1: return None @@ -582,7 +580,6 @@ def partition_first_outputs( assert self.map_parameter_compatible(map_exit_1.map, map_entry_2.map, state, sdfg) # Tests if consumer consume less or equal the amount we generate. if not producer_subset.covers(consumer_edge.data.src_subset): - print(f"NOT COVER | {producer_subset} | {consumer_edge.data.src_subset}") return None if consumer_node is map_entry_2: # Dynamic map range. return None @@ -627,9 +624,7 @@ def partition_first_outputs( return None else: # Ensure that there is no path that leads to the second map. - after_intermdiate_node = all_nodes_between( - graph=state, begin=edge.dst, end=map_entry_2 - ) + after_intermdiate_node = all_nodes_between(graph=state, begin=edge.dst, end=map_entry_2) if after_intermdiate_node is not None: return None # If we are here, then we know that the node is a shared output @@ -637,9 +632,7 @@ def partition_first_outputs( continue assert exclusive_outputs or shared_outputs or pure_outputs - assert len(processed_inter_nodes) == sum( - len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs] - ) + assert len(processed_inter_nodes) == sum(len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs]) return (pure_outputs, exclusive_outputs, shared_outputs) From d60904530457e83511dff167924bd9f1004d22cf Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 23 Aug 2024 15:19:14 +0200 Subject: [PATCH 007/115] Updated the tests of the map fusion. It almost passes all fuction. However, the one that needs renaming are not yet done. --- tests/transformations/mapfusion_test.py | 136 +++++++++++++++++------- 1 file changed, 99 insertions(+), 37 deletions(-) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 724c8c97ee..cb49dc32be 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -1,12 +1,81 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +from typing import Any, Union + import numpy as np import os import dace -from dace.transformation.dataflow import MapFusion + +from dace import SDFG, SDFGState +from dace.sdfg import nodes +from dace.transformation.dataflow import MapFusion, MapFusionOriginal + + +def count_node(sdfg: SDFG, node_type): + nb_nodes = 0 + for rsdfg in sdfg.all_sdfgs_recursive(): + for state in sdfg.states(): + for node in state.nodes(): + if isinstance(node, node_type): + nb_nodes += 1 + return nb_nodes + +def apply_fusion( + sdfg: SDFG, + removed_maps: Union[int, None] = None, + final_maps: Union[int, None] = None, +) -> SDFG: + """Applies the Map fusion transformation. + + The function checks that the number of maps has been reduced, it is also possible + to specify the number of removed maps. It is also possible to specify the final + number of maps. + """ + num_maps_before = count_node(sdfg, nodes.MapEntry) + sdfg.apply_transformations_repeated(MapFusion, validate=True, validate_all=True) + num_maps_after = count_node(sdfg, nodes.MapEntry) + + has_processed = False + if removed_maps is not None: + has_processed = True + rm = num_maps_before - num_maps_after + assert rm == removed_maps, f"Expected to remove {removed_maps} but removed {rm}" + if final_maps is not None: + has_processed = True + assert final_maps == num_maps_after, f"Expected that only {final_maps} maps remain, but there are sill {num_maps_after}." + if not has_processed: + assert num_maps_after < num_maps_before, f"Maps after: {num_maps_after}; Maps before: {num_maps_before}" + return sdfg @dace.program -def fusion(A: dace.float32[10, 20], B: dace.float32[10, 20], out: dace.float32[1]): +def fusion_simple(A: dace.float32[10, 20], B: dace.float32[10, 20], out: dace.float32[1]): + tmp = dace.define_local([10, 20], dtype=A.dtype) + tmp_2 = dace.define_local([10, 20], dtype=A.dtype) + for i, j in dace.map[0:10, 0:20]: + with dace.tasklet: + a << A[i, j] + b >> tmp[i, j] + + b = a * a + + for i, j in dace.map[0:10, 0:20]: + with dace.tasklet: + a << tmp[i, j] + b << B[i, j] + c >> tmp_2[i, j] + + c = a + b + + for i, j in dace.map[0:10, 0:20]: + with dace.tasklet: + a << tmp_2[i, j] + b >> out(1, lambda a, b: a + b)[0] + + b = a + + +@dace.program +def fusion_rename(A: dace.float32[10, 20], B: dace.float32[10, 20], out: dace.float32[1]): tmp = dace.define_local([10, 20], dtype=A.dtype) tmp_2 = dace.define_local([10, 20], dtype=A.dtype) for i, j in dace.map[0:10, 0:20]: @@ -65,13 +134,22 @@ def fusion_chain(A: dace.float32[10, 20], B: dace.float32[10, 20]): tmp2 = tmp1 * 4 B[:] = tmp2 + 5 - def test_fusion_simple(): - sdfg = fusion.to_sdfg() - sdfg.save(os.path.join('_dacegraphs', 'before1.sdfg')) - sdfg.simplify() - sdfg.apply_transformations_repeated(MapFusion) - sdfg.save(os.path.join('_dacegraphs', 'after1.sdfg')) + sdfg = fusion_simple.to_sdfg() + sdfg = apply_fusion(sdfg, final_maps=1) + + A = np.random.rand(10, 20).astype(np.float32) + B = np.random.rand(10, 20).astype(np.float32) + out = np.zeros(shape=1, dtype=np.float32) + sdfg(A=A, B=B, out=out) + + diff = abs(np.sum(A * A + B) - out) + print('Difference:', diff) + assert diff <= 1e-3 + +def test_fusion_rename(): + sdfg = fusion_rename.to_sdfg() + sdfg = apply_fusion(sdfg, final_maps=1) A = np.random.rand(10, 20).astype(np.float32) B = np.random.rand(10, 20).astype(np.float32) @@ -85,18 +163,10 @@ def test_fusion_simple(): def test_multiple_fusions(): sdfg = multiple_fusions.to_sdfg() - num_nodes_before = len([node for state in sdfg.nodes() for node in state.nodes()]) sdfg.save(os.path.join('_dacegraphs', 'before2.sdfg')) sdfg.simplify() - sdfg.apply_transformations_repeated(MapFusion) - sdfg.save(os.path.join('_dacegraphs', 'after2.sdfg')) - - num_nodes_after = len([node for state in sdfg.nodes() for node in state.nodes()]) - # Ensure that the number of nodes was reduced after transformation - if num_nodes_after >= num_nodes_before: - raise RuntimeError('SDFG was not properly transformed ' - '(nodes before: %d, after: %d)' % (num_nodes_before, num_nodes_after)) + sdfg = apply_fusion(sdfg) A = np.random.rand(10, 20).astype(np.float32) B = np.zeros_like(A) @@ -114,19 +184,8 @@ def test_multiple_fusions(): def test_fusion_chain(): sdfg = fusion_chain.to_sdfg() - sdfg.save(os.path.join('_dacegraphs', 'before3.sdfg')) sdfg.simplify() - sdfg.apply_transformations(MapFusion) - num_nodes_before = len([node for state in sdfg.nodes() for node in state.nodes()]) - sdfg.apply_transformations(MapFusion) - sdfg.apply_transformations(MapFusion) - sdfg.save(os.path.join('_dacegraphs', 'after3.sdfg')) - - num_nodes_after = len([node for state in sdfg.nodes() for node in state.nodes()]) - # Ensure that the number of nodes was reduced after transformation - if num_nodes_after >= num_nodes_before: - raise RuntimeError('SDFG was not properly transformed ' - '(nodes before: %d, after: %d)' % (num_nodes_before, num_nodes_after)) + sdfg = apply_fusion(sdfg, final_maps=1) A = np.random.rand(10, 20).astype(np.float32) B = np.zeros_like(A) @@ -158,7 +217,8 @@ def test_fusion_with_transient(): expected = A * A * 2 sdfg = fusion_with_transient.to_sdfg() sdfg.simplify() - sdfg.apply_transformations(MapFusion) + sdfg = apply_fusion(sdfg, removed_maps=2) + sdfg(A=A) assert np.allclose(A, expected) @@ -191,7 +251,7 @@ def build_sdfg(): return sdfg sdfg = build_sdfg() - sdfg.apply_transformations(MapFusion) + sdfg = apply_fusion(sdfg) A = np.random.rand(N, K) B = np.repeat(np.nan, N) @@ -217,10 +277,12 @@ def inverted_maps(A: dace.int32[10]): sdfg(A=val0) assert np.array_equal(val0, ref) - sdfg.apply_transformations(MapFusion) + # This can not be fused + apply_fusion(sdfg, removed_maps=0) + val1 = np.ndarray((10,), dtype=np.int32) sdfg(A=val1) - assert np.array_equal(val1, ref) + assert np.array_equal(val1, ref), f"REF: {ref}; VAL: {val1}" def test_fusion_with_empty_memlet(): @@ -240,8 +302,7 @@ def inner_product(A: dace.float32[N], B: dace.float32[N], out: dace.float32[1]): out[0] += lsum sdfg = inner_product.to_sdfg(simplify=True) - count = sdfg.apply_transformations_repeated(MapFusion) - assert count == 2 + apply_fusion(sdfg, removed_maps=2) A = np.arange(1024, dtype=np.float32) B = np.arange(1024, dtype=np.float32) @@ -265,7 +326,7 @@ def fusion_with_nested_sdfg_0(A: dace.int32[10], B: dace.int32[10], C: dace.int3 A[i] = tmp[i] * 2 sdfg = fusion_with_nested_sdfg_0.to_sdfg(simplify=True) - sdfg.apply_transformations(MapFusion) + apply_fusion(sdfg) for sd in sdfg.all_sdfgs_recursive(): if sd is not sdfg: @@ -295,7 +356,7 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 B[i] = tmp[i] * 2 sdfg = fusion_with_nested_sdfg_1.to_sdfg(simplify=True) - sdfg.apply_transformations(MapFusion) + apply_fusion(sdfg) if len(sdfg.states()) != 1: return @@ -315,6 +376,7 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 test_multiple_fusions() test_fusion_chain() test_fusion_with_transient() + test_fusion_rename() test_fusion_with_transient_scalar() test_fusion_with_inverted_indices() test_fusion_with_empty_memlet() From 52c4542eaed0ae94cd78954aad346968ed78929a Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 23 Aug 2024 16:07:25 +0200 Subject: [PATCH 008/115] WIP: Started with a renamer function. --- .../dataflow/map_fusion_helper.py | 53 +++++++++++++++++-- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index d52c96dbdd..ac83796fb7 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -266,10 +266,10 @@ def find_parameter_remapping( names of the first map (values). Because of how the replace function works the `dict` describes how to replace the parameters of the second map with parameters of the first map. - - If the remapping does not exists, the function will return `None`. - Parameters, that already have the correct names, will not be included in the - final mapping. + Parameters that already have the correct name and compatible range, are not + included in the return value, thus the keys and values are always different. + If no renaming at all is _needed_ then the function returns an empty `dict`. + If no remapping exists, then the function will return `None`. Args: first_map: The first map (these parameters will be replaced). @@ -282,6 +282,8 @@ def find_parameter_remapping( if len(first_params) != len(second_params): return None + if len(first_params) == 0: # Trivial maps + return {} # Allows to map a name to the corresponding index first_pmap: Dict[str, int] = {map_param: i for i, map_param in enumerate(first_params)} @@ -339,6 +341,49 @@ def find_parameter_remapping( assert len(unused_first_params) == 0 return final_mapping + def rename_map_parameters( + self, + first_map: nodes.Map, + second_map: nodes.Map, + ) -> None: + """Replaces the map parameters of the second map with names from the first. + + The replacement is done in a safe way, thus `{'i': 'j', 'j': 'i'}` is + handled correct. The function assumes that a proper replacement exists. + The replacement is computed by calling `self.find_parameter_remapping()`. + + Args: + first_map: The first map (these parameters will be replaced). + second_map: The second map, these parameters acts as source. + """ + + # Compute the replacement dict. + repl: Dict[str, str] = self.find_parameter_remapping(first_map=first_map, second_map=second_map) + + if repl is None: + raise RuntimeError("The replacement does not exist") + if len(repl) == 0: + return + + # Why is this thing is symbolic and not in replace? + symbolic.safe_replace( + mapping=repl, + replace_callback= + + + # Parameters are not replaced in the map even they are included + + + + + + + + + + + + def is_interstate_transient( self, transient: Union[str, nodes.AccessNode], From 3b758bf6fc79af195db1210924cd35e828bd50a4 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 28 Aug 2024 11:15:45 +0200 Subject: [PATCH 009/115] Continued with the parallel fusion stuff. --- dace/transformation/dataflow/__init__.py | 2 +- dace/transformation/dataflow/map_fusion.py | 1 - .../dataflow/map_fusion_helper.py | 318 +++++++++--------- .../dataflow/map_fusion_parallel.py | 14 +- .../dataflow/map_fusion_serial.py | 17 +- 5 files changed, 182 insertions(+), 170 deletions(-) diff --git a/dace/transformation/dataflow/__init__.py b/dace/transformation/dataflow/__init__.py index 9316949d70..dbd3838d9f 100644 --- a/dace/transformation/dataflow/__init__.py +++ b/dace/transformation/dataflow/__init__.py @@ -8,7 +8,7 @@ from .map_for_loop import MapToForLoop, MapToForLoopRegion from .map_interchange import MapInterchange from .map_dim_shuffle import MapDimShuffle -from .map_fusion import MapFusion, ParallelMapFusion, SerialMapFusion, MapFusionOriginal +from .map_fusion import MapFusion, ParallelMapFusion, SerialMapFusion from .map_fission import MapFission from .map_unroll import MapUnroll from .trivial_map_elimination import TrivialMapElimination diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index c0b458665e..3735d3e7dc 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -3,7 +3,6 @@ from .map_fusion_serial import SerialMapFusion from .map_fusion_parallel import ParallelMapFusion -from .map_fusion_original import MapFusionOriginal # Compatibility with previous versions of DaCe and clients. MapFusion = SerialMapFusion diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index ac83796fb7..db1c156408 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -5,13 +5,16 @@ import functools import itertools import re -from typing import Any, Dict, Iterable, List, Literal, Optional, Set, Sequence, Tuple, Union, overload +import copy +from typing import Any, Dict, Iterable, List, Literal, Optional, Set, Sequence, Tuple, Union, overload, TypeVar import dace from dace import data, properties, subsets, transformation, symbolic -from dace.sdfg import SDFG, SDFGState, graph, nodes, validation +from dace.sdfg import SDFG, SDFGState, graph, nodes, validation, replace from dace.transformation import helpers +_T = TypeVar("_T") + @properties.make_properties class MapFusionHelper(transformation.SingleStateTransformation): """Contains common part of the fusion for parallel and serial Map fusion. @@ -19,7 +22,7 @@ class MapFusionHelper(transformation.SingleStateTransformation): The transformation assumes that the SDFG obeys the principals outlined [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). The main advantage of this structure is, that it is rather easy to determine if a transient is used anywhere else. This check, performed by - `is_interstate_transient()`. It is further speeded up by cashing some computation, + `is_shared_data()`. It is further speeded up by cashing some computation, thus such an object should not be used after interstate optimizations were applied to the SDFG. @@ -40,13 +43,13 @@ class MapFusionHelper(transformation.SingleStateTransformation): allow_none=False, desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", ) - shared_transients = properties.DictProperty( + shared_data = properties.DictProperty( key_type=SDFG, value_type=set, #[str] default=None, allow_none=True, - desc="Maps SDFGs to the set of array transients that can not be removed. " - "The variable acts as a cache, and is managed by 'is_interstate_transient()'.", + desc="Maps SDFGs to the set of data that can not be removed. " + "The variable acts as a cache, and is managed by 'is_shared_data()'.", ) def __init__( @@ -60,12 +63,14 @@ def __init__( self.only_toplevel_maps = bool(only_toplevel_maps) if only_inner_maps is not None: self.only_inner_maps = bool(only_inner_maps) - self.shared_transients = {} + self.shared_data = {} + @classmethod def expressions(cls) -> bool: raise RuntimeError("The `_MapFusionHelper` is not a transformation on its own.") + def can_be_fused( self, map_entry_1: nodes.MapEntry, @@ -113,14 +118,13 @@ def can_be_fused( elif is_nested_sdfg(sdfg): return False - # We will now check if there exists a "remapping" that we can use. - if not self.map_parameter_compatible( - map_1=map_entry_1.map, map_2=map_entry_2.map, state=graph, sdfg=sdfg - ): + # We will now check if there exists a remapping that of the map parameter + if self.find_parameter_remapping(first_map=map_entry_1.map, second_map=map_entry_2.map) is None: return False return True + @staticmethod def relocate_nodes( from_node: Union[nodes.MapExit, nodes.MapEntry], @@ -216,44 +220,6 @@ def relocate_nodes( assert len(from_node.in_connectors) == 0 assert len(from_node.out_connectors) == 0 - @staticmethod - def map_parameter_compatible( - map_1: nodes.Map, - map_2: nodes.Map, - state: Union[SDFGState, SDFG], - sdfg: SDFG, - ) -> bool: - """Checks if the parameters of `map_1` are compatible with `map_2`. - - The check follows the following rules: - - The names of the map variables must be the same, i.e. no renaming - is performed. - - The ranges must be the same. - """ - range_1: subsets.Range = map_1.range - params_1: Sequence[str] = map_1.params - range_2: subsets.Range = map_2.range - params_2: Sequence[str] = map_2.params - - # The maps are only fuseable if we have an exact match in the parameter names - # this is because we do not do any renaming. This is in accordance with the - # rules. - if set(params_1) != set(params_2): - return False - - # Maps the name of a parameter to the dimension index - param_dim_map_1: Dict[str, int] = {pname: i for i, pname in enumerate(params_1)} - param_dim_map_2: Dict[str, int] = {pname: i for i, pname in enumerate(params_2)} - - # To fuse the two maps the ranges must have the same ranges - for pname in params_1: - idx_1 = param_dim_map_1[pname] - idx_2 = param_dim_map_2[pname] - # TODO(phimuell): do we need to call simplify? - if range_1[idx_1] != range_2[idx_2]: - return False - - return True @staticmethod def find_parameter_remapping( @@ -268,7 +234,8 @@ def find_parameter_remapping( with parameters of the first map. Parameters that already have the correct name and compatible range, are not included in the return value, thus the keys and values are always different. - If no renaming at all is _needed_ then the function returns an empty `dict`. + If no renaming at is _needed_, i.e. all parameter have the same name and range, + then the function returns an empty `dict`. If no remapping exists, then the function will return `None`. Args: @@ -282,20 +249,20 @@ def find_parameter_remapping( if len(first_params) != len(second_params): return None - if len(first_params) == 0: # Trivial maps - return {} - - # Allows to map a name to the corresponding index - first_pmap: Dict[str, int] = {map_param: i for i, map_param in enumerate(first_params)} - second_pmap: Dict[str, int] = {map_param: i for i, map_param in enumerate(second_map)} # The ranges, however, we apply some post processing to them. simp = lambda e: symbolic.simplify_ext(symbolic.simplify(e)) - first_rngs: List[Tuple[Any, Any, Any]] = [tuple(simp(r) for r in rng) for rng in first_map.range] - second_rngs: List[Tuple[Any, Any, Any]] = [tuple(simp(r) for r in rng) for rng in second_map.range] - - # These are the parameters of the second map that have not yet associated to - # a parameter of the second map. + first_rngs: Dict[str, Tuple[Any, Any, Any]] = { + param: tuple(simp(r) for r in rng) + for param, rng in zip(first_params, first_map.range) + } + second_rngs: Dict[Tuple[Any, Any, Any]] = { + param: tuple(simp(r) for r in rng) + for param, rng in zip(second_params, second_map.range) + } + + # Parameters of the second map that have not yet been matched to a parameter + # of the first map and vice versa. unmapped_second_params: Set[str] = set(second_params) unused_first_params: Set[str] = set(first_params) @@ -303,18 +270,17 @@ def find_parameter_remapping( # is needed then the parameter is not present in the mapping. final_mapping: Dict[str, str] = {} - # First we check if a remapping is needed at all, for this we look at the - # parameter that are the same in both maps. + # First we identify the parameters that already have the correct name. for param in set(first_params).intersection(second_params): - first_rng = first_rngs[first_pmap[param]] - second_rng = second_rngs[second_pmap[param]] + first_rng = first_rngs[param] + second_rng = second_rngs[param] if first_rng == second_rng: # They have the same name and the same range, this is already a match. # Because the names are already the same, we do not have to enter them # in the `final_mapping` - unmapped_second_params.pop(param) - unused_first_params.pop(param) + unmapped_second_params.discard(param) + unused_first_params.discard(param) # Check if no remapping is needed. if len(unmapped_second_params) == 0: @@ -323,28 +289,31 @@ def find_parameter_remapping( # Now we go through all the parameters that we have not mapped yet. # All of them will result in a remapping. for unmapped_second_param in unmapped_second_params: - second_rng = second_rngs[second_pmap[unmapped_second_param]] + second_rng = second_rngs[unmapped_second_param] assert unmapped_second_param not in final_mapping # Now look in all not yet used parameters of the first map which to use. for candidate_param in unused_first_params: - candidate_rng = first_rngs[first_pmap[candidate_param]] + candidate_rng = first_rngs[candidate_param] if candidate_rng == second_rng: final_mapping[unmapped_second_param] = candidate_param + unused_first_params.discard(candidate_param) break else: # We did not find a candidate, so the remapping does not exist return None - unused_first_params.pop(final_mapping[unmapped_second_param]) - assert len(unused_first_params) == 0 + assert len(final_mapping) == len(first_params) return final_mapping + def rename_map_parameters( self, first_map: nodes.Map, second_map: nodes.Map, + second_map_entry: nodes.MapEntry, + state: SDFGState, ) -> None: """Replaces the map parameters of the second map with names from the first. @@ -353,112 +322,118 @@ def rename_map_parameters( The replacement is computed by calling `self.find_parameter_remapping()`. Args: - first_map: The first map (these parameters will be replaced). - second_map: The second map, these parameters acts as source. + first_map: The first map (these are the final parameter). + second_map: The second map, this map will be replaced. + second_map_entry: The entry node of the second map. + state: The SDFGState on which we operate. """ - # Compute the replacement dict. - repl: Dict[str, str] = self.find_parameter_remapping(first_map=first_map, second_map=second_map) + repl_dict: Dict[str, str] = self.find_parameter_remapping(first_map=first_map, second_map=second_map) - if repl is None: + if repl_dict is None: raise RuntimeError("The replacement does not exist") - if len(repl) == 0: + if len(repl_dict) == 0: return + second_map_scope = state.scope_subgraph(entry_node=second_map_entry) # Why is this thing is symbolic and not in replace? symbolic.safe_replace( - mapping=repl, - replace_callback= - - - # Parameters are not replaced in the map even they are included - - - - - - - - + mapping=repl_dict, + replace_callback=second_map_scope.replace_dict, + ) + # For some odd reason the replace function does not modify the range and + # parameter of the map, so we will do it the hard way. + second_map.params = copy.deepcopy(first_map.params) + second_map.range = copy.deepcopy(first_map.range) - - def is_interstate_transient( + def is_shared_data( self, - transient: Union[str, nodes.AccessNode], + data: nodes.AccessNode, sdfg: dace.SDFG, - state: dace.SDFGState, ) -> bool: - """Tests if `transient` is an interstate transient, an can not be removed. + """Tests if `data` is interstate data, an can not be removed. + + Interstate data is used to transmit data between multiple state or + by extension within the state, and thus can not be removed by the + serial map fusion. - Essentially this function checks if a transient might be needed in a - different state in the SDFG, because it transmit information from - one state to the other. - If only the name of the data container is passed the function will - first look for an corresponding access node. + The function determine this properties, according to the following rules: + - The access node must be in the top scope. + - The underlying data is global. + - The `data` descriptor is used multiple times with the same state. + - `data` has an out or in degree of zero. + - The underlying data is referred to in another state. - The set of these "interstate transients" is computed once per SDFG. - The result is then cached internally for later reuse. + The function computes this information and then caches it for later use. Args: transient: The transient that should be checked. sdfg: The SDFG containing the array. - state: If given the state the node is located in. Note: - This function build upon the structure of the SDFG that is outlined - in the HackMD document. + - This function does not inspect the interstate edges, instead the + set of data that is accessed in interstate edges is approximated + with the set of sink nodes. + - This function works best if the SDFG uses SSA style. """ + if sdfg not in self.shared_data: + self._compute_shared_data(sdfg) + return data.data in self.shared_data[sdfg] - # According to [rule 6](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG) - # the set of such transients is partially given by all source access nodes. - # Because of rule 3 we also include all scalars in this set, as an over - # approximation. Furthermore, because simplify might violate rule 3, - # we also include the sink nodes. - # See if we have already computed the set - if sdfg in self.shared_transients: - shared_sdfg_transients: Set[str] = self.shared_transients[sdfg] - else: - # SDFG is not known so we have to compute the set. - shared_sdfg_transients = set() - for state_to_scan in sdfg.all_states(): - # TODO(phimuell): Use `all_nodes_recursive()` once it is available. - shared_sdfg_transients.update( - [ - node.data - for node in itertools.chain( - state_to_scan.source_nodes(), state_to_scan.sink_nodes() - ) - if isinstance(node, nodes.AccessNode) - and sdfg.arrays[node.data].transient - ] - ) - self.shared_transients[sdfg] = shared_sdfg_transients - - if isinstance(transient, str): - name = transient - matching_access_nodes = [node for node in state.data_nodes() if node.data == name] - # Rule 8: There is only one access node per state for data. - assert len(matching_access_nodes) == 1 - transient = matching_access_nodes[0] - else: - assert isinstance(transient, nodes.AccessNode) - name = transient.data + def _compute_shared_data( + self, + sdfg: dace.SDFG, + ) -> None: + """This function computes the set of shared data for SDFG `sdfg`. - desc: data.Data = sdfg.arrays[name] - if not desc.transient: - return True - if isinstance(desc, data.Scalar): - return True # Scalars can not be removed by fusion anyway. + See the documentation for `self.is_shared_data()` for a description. - # Rule 8: If degree larger than one then it is used within the state. - if state.out_degree(transient) > 1: - return True + Args: + sdfg: The SDFG for which the set of shared data should be computed. + """ + # Shared data of this SDFG. + shared_data: Set[str] = set() + + # Add all global data. + for data_name, data_desc in sdfg.arrays.items(): + if not data_desc.transient: + shared_data.add(data_name) + + # We go through all states and classify the nodes, according to the rules. + prevously_seen_data: Set[str] = set() + for state in sdfg.nodes(): + scope_dict = state.scope_dict() + for access_node in state.data_nodes(): + if scope_dict[access_node] is not None: + # We are only interested in global data. + pass + elif access_node.data in shared_data: + # The data was already determined to be shared data + pass + elif access_node.data in prevously_seen_data: + # We have seen this data before, either in this state or in + # a previous one, but we did not classifies it as shared, + # let's do this now. Note that we do not remove the data + # also from `previously_seen_data`. + shared_data.add(access_node.data) + elif state.out_degree(access_node) == 0: + # Sink and source nodes also have to be kept. + shared_data.add(access_node.data) + elif state.in_degree(access_node) == 0: + shared_data.add(access_node.data) + else: + # The node was not classified as shared data, so we record that + # we saw it. Note that a node that was immediately classified + # as shared node will never be added to this set, but a data + # that was found twice will be inside this list. + prevously_seen_data.add(access_node.data) + + # Update the internal cache + self.shared_data[sdfg] = shared_data - # Now we check if it is used in a different state. - return name in shared_sdfg_transients def partition_first_outputs( self, @@ -514,6 +489,13 @@ def partition_first_outputs( map_params_1: Sequence[str] = map_exit_1.map.params map_params_2: Sequence[str] = map_entry_2.map.params + # Compute the renaming that for translating the parameter of the _second_ + # map to the ones used by the first map. + repl_dict: Dict[str, str] = self.find_parameter_remapping( + first_map=map_exit_1.map, + second_map=map_entry_2.map, + ) + # Set of intermediate nodes that we have already processed. processed_inter_nodes: Set[nodes.Node] = set() @@ -603,8 +585,9 @@ def partition_first_outputs( if producer_edge.data.wcr is not None: return None - if len(downstream_nodes) == 0: + # TODO(phimuell): Refactor this if such that the loop is only there once. + # There is nothing between intermediate node and the entry of the # second map, thus the edge belongs either in `\mathbb{S}` or # `\mathbb{E}`. @@ -622,21 +605,28 @@ def partition_first_outputs( consumers = find_downstream_consumers(state=state, begin=intermediate_node) for consumer_node, consumer_edge in consumers: - assert self.map_parameter_compatible(map_exit_1.map, map_entry_2.map, state, sdfg) - # Tests if consumer consume less or equal the amount we generate. - if not producer_subset.covers(consumer_edge.data.src_subset): - return None if consumer_node is map_entry_2: # Dynamic map range. return None if isinstance(consumer_node, nodes.LibraryNode): # TODO(phimuell): Allow some library nodes. return None + # Tests if consumer consume less or equal the amount we generate. + # If the source of the consumer is not set, we can not check what it reads. + if consumer_edge.data.src_subset is None: + return None + # For that we have to perform a replace operation of the consumer . + mod_consumer_subset = copy.deepcopy(consumer_edge.data.src_subset) + symbolic.safe_replace(mapping=repl_dict, replace_callback=mod_consumer_subset.replace) + + if not producer_subset.covers(mod_consumer_subset): + return None + # Note that "remove" has a special meaning here, regardless of the # output of the check function, from within the second map we remove # the intermediate, it has more the meaning of "do we need to # reconstruct it after the second map again?" - if self.is_interstate_transient(intermediate_node, sdfg, state): + if self.is_shared_data(intermediate_node, sdfg): shared_outputs.add(out_edge) else: exclusive_outputs.add(out_edge) @@ -658,15 +648,21 @@ def partition_first_outputs( found_second_entry = True consumers = find_downstream_consumers(state=state, begin=edge) for consumer_node, feed_edge in consumers: - assert self.map_parameter_compatible(map_exit_1.map, map_entry_2.map, state, sdfg) - # Tests if consumer consume less or equal the amount we generate. - if not producer_subset.covers(consumer_edge.data.src_subset): - return None if consumer_node is map_entry_2: # Dynamic map range. return None if isinstance(consumer_node, nodes.LibraryNode): # TODO(phimuell): Allow some library nodes. return None + + # Tests if consumer consume less or equal the amount we generate. + # If the source of the consumer is not set, we can not check what it reads. + if consumer_edge.data.src_subset is None: + return None + # For that we have to perform a replace operation of the consumer . + mod_consumer_subset = copy.deepcopy(consumer_edge.data.src_subset) + symbolic.safe_replace(mapping=repl_dict, replace_callback=mod_consumer_subset.replace) + if not producer_subset.covers(mod_consumer_subset): + return None else: # Ensure that there is no path that leads to the second map. after_intermdiate_node = all_nodes_between(graph=state, begin=edge.dst, end=map_entry_2) @@ -858,7 +854,3 @@ def find_upstream_producers( only_tasklets=only_tasklets, reverse=True, ) - - - - diff --git a/dace/transformation/dataflow/map_fusion_parallel.py b/dace/transformation/dataflow/map_fusion_parallel.py index 0c032cc5f2..eae550c5e3 100644 --- a/dace/transformation/dataflow/map_fusion_parallel.py +++ b/dace/transformation/dataflow/map_fusion_parallel.py @@ -39,6 +39,7 @@ class ParallelMapFusion(map_fusion_helper.MapFusionHelper): desc="Only perform fusing if the Maps share a node as parent.", ) + def __init__( self, only_if_common_ancestor: Optional[bool] = None, @@ -48,6 +49,7 @@ def __init__( self.only_if_common_ancestor = only_if_common_ancestor super().__init__(**kwargs) + @classmethod def expressions(cls) -> Any: # This just matches _any_ two Maps inside a state. @@ -55,6 +57,7 @@ def expressions(cls) -> Any: state.add_nodes_from([cls.map_entry1, cls.map_entry2]) return [state] + def can_be_applied( self, graph: Union[SDFGState, SDFG], @@ -67,7 +70,7 @@ def can_be_applied( map_entry_2: nodes.MapEntry = self.map_entry2 # Check the structural properties of the maps, this will also ensure that - # the two maps are in the same scope. + # the two maps are in the same scope and the parameters can be renamed if not self.can_be_fused( map_entry_1=map_entry_1, map_entry_2=map_entry_2, @@ -92,6 +95,7 @@ def can_be_applied( return True + def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: """Performs the Map fusing. @@ -106,6 +110,14 @@ def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: map_entry_2: nodes.MapEntry = self.map_entry2 map_exit_2: nodes.MapExit = graph.exit_node(map_entry_2) + # Before we do anything we perform the renaming. + self.rename_map_parameters( + first_map=map_entry_1.map, + second_map=map_entry_2.map, + second_map_entry=map_entry_2, + state=graph, + ) + for to_node, from_node in zip((map_entry_1, map_exit_1), (map_entry_2, map_exit_2)): self.relocate_nodes( from_node=from_node, diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py index f80ca21240..6667e0cc53 100644 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -40,6 +40,7 @@ class SerialMapFusion(map_fusion_helper.MapFusionHelper): Notes: - This transformation modifies more nodes than it matches! + - Run simplify to get ri of excess keep alive nodes """ map_exit1 = transformation.transformation.PatternNode(nodes.MapExit) @@ -52,6 +53,7 @@ def __init__( ) -> None: super().__init__(**kwargs) + @classmethod def expressions(cls) -> Any: """Get the match expression. @@ -64,6 +66,7 @@ def expressions(cls) -> Any: """ return [dace.sdfg.utils.node_path_graph(cls.map_exit1, cls.access_node, cls.map_entry2)] + def can_be_applied( self, graph: Union[SDFGState, SDFG], @@ -105,6 +108,7 @@ def can_be_applied( return False return True + def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: """Performs the serial Map fusing. @@ -124,13 +128,20 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non assert isinstance(graph, dace.SDFGState) assert isinstance(self.map_exit1, nodes.MapExit) assert isinstance(self.map_entry2, nodes.MapEntry) - assert self.map_parameter_compatible(self.map_exit1.map, self.map_entry2.map, graph, sdfg) map_exit_1: nodes.MapExit = self.map_exit1 map_entry_2: nodes.MapEntry = self.map_entry2 map_exit_2: nodes.MapExit = graph.exit_node(self.map_entry2) map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit1) + # Before we do anything we perform the renaming. + self.rename_map_parameters( + first_map=map_exit_1.map, + second_map=map_entry_2.map, + second_map_entry=map_entry_2, + state=graph, + ) + output_partition = self.partition_first_outputs( state=graph, sdfg=sdfg, @@ -186,6 +197,7 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non # Now turn the second output node into the output node of the first Map. map_exit_2.map = map_entry_1.map + @staticmethod def handle_intermediate_set( intermediate_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]], @@ -215,9 +227,6 @@ def handle_intermediate_set( Notes: Before the transformation the `state` does not have to be valid and after this function has run the state is (most likely) invalid. - - Todo: - Rewrite using `MemletTree`. """ # Essentially this function removes the AccessNode between the two maps. From 377b428d241194ae10d7fca3c2014710f80641a6 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 28 Aug 2024 14:41:06 +0200 Subject: [PATCH 010/115] The fusion transformation now also checks if there is a write conflict in the input and output set. However, it is very simple. --- .../dataflow/map_fusion_helper.py | 101 ++++++++++++++++++ 1 file changed, 101 insertions(+) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index db1c156408..c2c0b3d862 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -122,6 +122,16 @@ def can_be_fused( if self.find_parameter_remapping(first_map=map_entry_1.map, second_map=map_entry_2.map) is None: return False + # Test for read-write conflicts + if self.has_read_write_dependency( + map_entry_1= + map_entry_1, + map_exit_2=graph.exit_node(map_entry_2), + state=graph, + sdfg=sdfg, + ): + return False + return True @@ -435,6 +445,56 @@ def _compute_shared_data( self.shared_data[sdfg] = shared_data + def has_read_write_dependency( + self, + map_entry_1: nodes.MapEntry, + map_exit_2: nodes.MapExit, + state: SDFGState, + sdfg: SDFG, + ) -> bool: + """Test if there is a read write dependency between the two maps. + + The function checks if the first map does not read anything from + a data descriptor, the second map writes into. + + Returns: + `True` if there is a conflict between input and outputs, `False` + if there is no conflict. + + Args: + map_entry_1: The entry node of the first map. + map_entry_2: The exit node of the second map. + state: The state on which we operate. + + Note: + The current implementation just computes the set of data that is + used as input and for output. If there is an intersection then + the function considers this as a read write conflict. + """ + + # Determine the set of data that is used as input and as output + inputs_data: Set[str] = set() + outputs_data: Set[str] = set() + for input_edge in state.in_edges(map_entry_1): + input_node: nodes.Node = input_edge.src + if not isinstance(input_node, nodes.AccessNode): + continue + input_desc: data.Data = input_node.desc(sdfg) + inputs_data.add(track_view(input_node, state, sdfg).data if isinstance(input_desc, data.View) else input_node.data) + for output_edge in state.out_edges(map_exit_2): + output_node: nodes.Node = output_edge.dst + if not isinstance(output_node, nodes.AccessNode): + continue + output_desc: data.Data = output_node.desc(sdfg) + outputs_data.add(track_view(output_node, state, sdfg).data if isinstance(output_desc, data.View) else output_node.data) + + # There is no intersection, thus they read and write to distinct data. + if not outputs_data.intersection(inputs_data): + return False + + return True + + def partition_first_outputs( self, state: SDFGState, @@ -854,3 +914,44 @@ def find_upstream_producers( only_tasklets=only_tasklets, reverse=True, ) + + +def track_view( + view: nodes.AccessNode, + state: SDFGState, + sdfg: SDFG, +) -> nodes.AccessNode: + """Find the original data of a View. + + Given the View `view`, the function will trace the view back to the + original access node. + + Args: + view: The view that should be traced. + state: The state in which we operate. + sdfg: The SDFG on which we operate. + """ + # First determine if the view is used for reading or writing. + assert isinstance(view.desc(sdfg), data.View) + curr_edge = dace.sdfg.utils.get_view_edge(state, view) + + if curr_edge is None: + raise RuntimeError(f"Failed to determine the direction of the view '{view}'.") + if curr_edge.dst_conn == "view": + # The view is used for reading. + next_node = lambda curr_edge: curr_edge.src + elif curr_edge.src_conn == "view": + # The view is used for writing. + next_node = lambda curr_edge: curr_edge.dst + else: + raise RuntimeError("Failed to determine the direction of the view '{view}'.") + + # Now trace the view back. + org_view = view + view = next_node(curr_edge) + while isinstance(view.desc(sdfg), data.View): + curr_edge = dace.sdfg.utils.get_view_edge(state, view) + if curr_edge is None: + raise RuntimeError(f"View tracing of '{org_view}' failed at note '{view}'.") + view = next_node(curr_edge) + return view From db4864b723928269f8461cbd5b890da88edfa991 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 28 Aug 2024 14:58:45 +0200 Subject: [PATCH 011/115] Updated some tests. --- .../mapfusion_data_races_test.py | 28 ++++++++++++++++++- tests/transformations/mapfusion_test.py | 8 +++++- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/tests/transformations/mapfusion_data_races_test.py b/tests/transformations/mapfusion_data_races_test.py index e765ec6978..13a71370e1 100644 --- a/tests/transformations/mapfusion_data_races_test.py +++ b/tests/transformations/mapfusion_data_races_test.py @@ -41,6 +41,13 @@ def rw_data_race_3(A: dace.float64[20], B: dace.float64[20]): A[:10] += 3.0 * offset(A[:11]) +@dace.program +def rw_data_race_4(A: dace.float64[20], B: dace.float64[20]): + # This is potentially fusable + A += B + A *= 2.0 + + def test_rw_data_race(): sdfg = rw_data_race.to_sdfg(simplify=True) sdfg.apply_transformations_repeated(MapFusion) @@ -50,8 +57,9 @@ def test_rw_data_race(): def test_rw_data_race_2_mf(): sdfg = rw_data_race_2.to_sdfg(simplify=True) - sdfg.apply_transformations_repeated(MapFusion) + nb_applied = sdfg.apply_transformations_repeated(MapFusion) map_entry_nodes = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, nodes.MapEntry)] + assert nb_applied > 0 assert (len(map_entry_nodes) > 1) @@ -69,8 +77,26 @@ def test_rw_data_race_3_sgf(): assert (len(map_entry_nodes) > 1) +def test_rw_data_race_3_mf(): + sdfg = rw_data_race_3.to_sdfg(simplify=True) + nb_applied = sdfg.apply_transformations_repeated(MapFusion) + map_entry_nodes = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, nodes.MapEntry)] + assert (len(map_entry_nodes) > 1) + assert nb_applied > 0 + + +def test_rw_data_race_4_mf(): + # It is technically possible to fuse it, because there is only a point wise dependency. + sdfg = rw_data_race_4.to_sdfg(simplify=True) + sdfg.apply_transformations_repeated(MapFusion) + map_entry_nodes = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, nodes.MapEntry)] + assert (len(map_entry_nodes) >= 1) + + if __name__ == "__main__": test_rw_data_race() test_rw_data_race_2_mf() test_rw_data_race_2_sgf() test_rw_data_race_3_sgf() + test_rw_data_race_3_mf() + test_rw_data_race_4_mf() diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index cb49dc32be..3e89b3e99b 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -38,11 +38,17 @@ def apply_fusion( if removed_maps is not None: has_processed = True rm = num_maps_before - num_maps_after + if not (rm == removed_maps): + sdfg.view() assert rm == removed_maps, f"Expected to remove {removed_maps} but removed {rm}" if final_maps is not None: has_processed = True + if not (final_maps == num_maps_after): + sdfg.view() assert final_maps == num_maps_after, f"Expected that only {final_maps} maps remain, but there are sill {num_maps_after}." if not has_processed: + if not (num_maps_after < num_maps_before): + sdfg.view() assert num_maps_after < num_maps_before, f"Maps after: {num_maps_after}; Maps before: {num_maps_before}" return sdfg @@ -372,11 +378,11 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 if __name__ == '__main__': + test_fusion_rename() test_fusion_simple() test_multiple_fusions() test_fusion_chain() test_fusion_with_transient() - test_fusion_rename() test_fusion_with_transient_scalar() test_fusion_with_inverted_indices() test_fusion_with_empty_memlet() From f395acdb73ed986d6092b0398a34860e63af45bf Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 28 Aug 2024 15:17:14 +0200 Subject: [PATCH 012/115] Fixed an error. I shouild refactor that damn loop. --- dace/transformation/dataflow/map_fusion_helper.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index c2c0b3d862..cc37e246e7 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -700,6 +700,12 @@ def partition_first_outputs( # - The intermediate has a single connection to the second map, that # fulfills the restriction outlined above. # - All other connections have no connection to the second map. + + # We can fuse the maps only if the producer, i.e. the top map, + # represented by `producer_edge`, is gives us enough information. + producer_subset: subsets.Range = producer_edge.data.dst_subset + consumers = find_downstream_consumers(state=state, begin=intermediate_node) + found_second_entry = False for edge in state.out_edges(intermediate_node): if edge.dst is map_entry_2: @@ -707,7 +713,7 @@ def partition_first_outputs( return None found_second_entry = True consumers = find_downstream_consumers(state=state, begin=edge) - for consumer_node, feed_edge in consumers: + for consumer_node, consumer_edge in consumers: if consumer_node is map_entry_2: # Dynamic map range. return None if isinstance(consumer_node, nodes.LibraryNode): From b1ab95ed7946be82dbb4c00cd904cff9273787d1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 28 Aug 2024 15:55:25 +0200 Subject: [PATCH 013/115] Some improvements to the tests. --- tests/transformations/apply_to_test.py | 9 +++++---- tests/transformations/warp_tiling_test.py | 20 ++++++++++++++------ 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/tests/transformations/apply_to_test.py b/tests/transformations/apply_to_test.py index f4cd832c3e..6bbc8ef8bf 100644 --- a/tests/transformations/apply_to_test.py +++ b/tests/transformations/apply_to_test.py @@ -2,6 +2,7 @@ """ Tests the `apply_to` transformation API. """ import dace from dace.sdfg import utils as sdutil +#from dace.transformation.dataflow import MapFusionOriginal as MapFusion from dace.transformation.dataflow import MapFusion from dace.transformation.subgraph import SubgraphFusion from dace.transformation.passes.pattern_matching import enumerate_matches @@ -31,7 +32,7 @@ def test_applyto_pattern(): transient = next(aname for aname, desc in sdfg.arrays.items() if desc.transient) access_node = next(n for n in state.nodes() if isinstance(n, dace.nodes.AccessNode) and n.data == transient) - MapFusion.apply_to(sdfg, first_map_exit=mult_exit, array=access_node, second_map_entry=add_entry) + MapFusion.apply_to(sdfg, map_exit1=mult_exit, access_node=access_node, map_entry2=add_entry) def test_applyto_enumerate(): @@ -42,9 +43,9 @@ def test_applyto_enumerate(): pattern = sdutil.node_path_graph(dace.nodes.MapExit, dace.nodes.AccessNode, dace.nodes.MapEntry) for subgraph in enumerate_matches(sdfg, pattern): MapFusion.apply_to(sdfg, - first_map_exit=subgraph.source_nodes()[0], - array=next(n for n in subgraph.nodes() if isinstance(n, dace.nodes.AccessNode)), - second_map_entry=subgraph.sink_nodes()[0]) + map_exit1=subgraph.source_nodes()[0], + access_node=next(n for n in subgraph.nodes() if isinstance(n, dace.nodes.AccessNode)), + map_entry2=subgraph.sink_nodes()[0]) def test_applyto_subgraph(): diff --git a/tests/transformations/warp_tiling_test.py b/tests/transformations/warp_tiling_test.py index 7c75d08878..4ca424a1c1 100644 --- a/tests/transformations/warp_tiling_test.py +++ b/tests/transformations/warp_tiling_test.py @@ -38,19 +38,27 @@ def test_warp_softmax(vector_length=1): sdfg = softmax_fwd.to_sdfg(simplify=True) # Apply transformations - sdfg.apply_transformations_repeated(ReduceExpansion) + sdfg.apply_transformations_repeated(ReduceExpansion, validate_all=True) MultiExpansion.apply_to(sdfg, sdfg.node(0).nodes()) SubgraphFusion.apply_to(sdfg, sdfg.node(0).nodes()) sdfg.expand_library_nodes() sdfg.simplify() - sdfg.apply_transformations_repeated([TrivialMapElimination, MapFusion]) - sdfg.apply_transformations(GPUTransformSDFG) + sdfg.apply_transformations_repeated([TrivialMapElimination, MapFusion], validate_all=True) + sdfg.apply_transformations(GPUTransformSDFG, validate_all=True) assert sdfg.apply_transformations(WarpTiling) == 1 - sdfg.apply_transformations_repeated([HoistState, InlineSDFG, StateFusion]) - sdfg.apply_transformations_repeated([TrivialMapElimination, MapFusion]) + sdfg.apply_transformations_repeated([HoistState, InlineSDFG, StateFusion], validate_all=True) + sdfg.apply_transformations_repeated([TrivialMapElimination, MapFusion], validate_all=True) if vector_length != 1: sdfg.apply_transformations_repeated( - Vectorization, dict(vector_len=vector_length, preamble=False, postamble=False, strided_map=False)) + Vectorization, + dict( + vector_len=vector_length, + preamble=False, + postamble=False, + strided_map=False + ), + validate_all=True + ) sdfg.specialize(dict(dn1=2, dn2=16, dn3=128, dr=128)) # Check validity From 945ca8f661aeac6a45da9fe08e0fefefb430430f Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 29 Aug 2024 07:42:01 +0200 Subject: [PATCH 014/115] Removed some debugging stuff. --- tests/transformations/apply_to_test.py | 1 - tests/transformations/mapfusion_test.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/transformations/apply_to_test.py b/tests/transformations/apply_to_test.py index 6bbc8ef8bf..2b1828f5ba 100644 --- a/tests/transformations/apply_to_test.py +++ b/tests/transformations/apply_to_test.py @@ -2,7 +2,6 @@ """ Tests the `apply_to` transformation API. """ import dace from dace.sdfg import utils as sdutil -#from dace.transformation.dataflow import MapFusionOriginal as MapFusion from dace.transformation.dataflow import MapFusion from dace.transformation.subgraph import SubgraphFusion from dace.transformation.passes.pattern_matching import enumerate_matches diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 3e89b3e99b..8055afee17 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -7,7 +7,7 @@ from dace import SDFG, SDFGState from dace.sdfg import nodes -from dace.transformation.dataflow import MapFusion, MapFusionOriginal +from dace.transformation.dataflow import MapFusion def count_node(sdfg: SDFG, node_type): From 940b9b6752d3e1f89bc9a801d431449758f0b357 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 29 Aug 2024 07:49:53 +0200 Subject: [PATCH 015/115] Fixed some typing stuff. --- dace/transformation/dataflow/map_fusion_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index cc37e246e7..f99def22bf 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -6,7 +6,7 @@ import itertools import re import copy -from typing import Any, Dict, Iterable, List, Literal, Optional, Set, Sequence, Tuple, Union, overload, TypeVar +from typing import Any, Dict, Iterable, List, Optional, Set, Sequence, TypeVar, Tuple, Union, overload import dace from dace import data, properties, subsets, transformation, symbolic From ecae36163a35bec82c9d2fc0b147c707bf357100 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 29 Aug 2024 11:48:30 +0200 Subject: [PATCH 016/115] Started with a better implementation for the data dependency test. --- .../dataflow/map_fusion_helper.py | 114 ++++++++---------- .../dataflow/map_fusion_serial.py | 109 ++++++++++++++++- 2 files changed, 156 insertions(+), 67 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index f99def22bf..21df314ad3 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -6,14 +6,13 @@ import itertools import re import copy -from typing import Any, Dict, Iterable, List, Optional, Set, Sequence, TypeVar, Tuple, Union, overload +from typing import Any, Dict, Iterable, List, Optional, Set, Sequence, Tuple, Union, overload import dace from dace import data, properties, subsets, transformation, symbolic from dace.sdfg import SDFG, SDFGState, graph, nodes, validation, replace from dace.transformation import helpers -_T = TypeVar("_T") @properties.make_properties class MapFusionHelper(transformation.SingleStateTransformation): @@ -122,16 +121,6 @@ def can_be_fused( if self.find_parameter_remapping(first_map=map_entry_1.map, second_map=map_entry_2.map) is None: return False - # Test for read-write conflicts - if self.has_read_write_dependency( - map_entry_1= - map_entry_1, - map_exit_2=graph.exit_node(map_entry_2), - state=graph, - sdfg=sdfg, - ): - return False - return True @@ -445,56 +434,6 @@ def _compute_shared_data( self.shared_data[sdfg] = shared_data - def has_read_write_dependency( - self, - map_entry_1: nodes.MapEntry, - map_exit_2: nodes.MapExit, - state: SDFGState, - sdfg: SDFG, - ) -> bool: - """Test if there is a read write dependency between the two maps. - - The function checks if the first map does not read anything from - a data descriptor, the second map writes into. - - Returns: - `True` if there is a conflict between input and outputs, `False` - if there is no conflict. - - Args: - map_entry_1: The entry node of the first map. - map_entry_2: The exit node of the second map. - state: The state on which we operate. - - Note: - The current implementation just computes the set of data that is - used as input and for output. If there is an intersection then - the function considers this as a read write conflict. - """ - - # Determine the set of data that is used as input and as output - inputs_data: Set[str] = set() - outputs_data: Set[str] = set() - for input_edge in state.in_edges(map_entry_1): - input_node: nodes.Node = input_edge.src - if not isinstance(input_node, nodes.AccessNode): - continue - input_desc: data.Data = input_node.desc(sdfg) - inputs_data.add(track_view(input_node, state, sdfg).data if isinstance(input_desc, data.View) else input_node.data) - for output_edge in state.out_edges(map_exit_2): - output_node: nodes.Node = output_edge.dst - if not isinstance(output_node, nodes.AccessNode): - continue - output_desc: data.Data = output_node.desc(sdfg) - outputs_data.add(track_view(output_node, state, sdfg).data if isinstance(output_desc, data.View) else output_node.data) - - # There is no intersection, thus they read and write to distinct data. - if not outputs_data.intersection(inputs_data): - return False - - return True - - def partition_first_outputs( self, state: SDFGState, @@ -922,6 +861,46 @@ def find_upstream_producers( ) +def get_access_set( + scope_node: Union[nodes.MapEntry, nodes.MapExit], + state: SDFGState, +) -> Set[nodes.AccessNode]: + """Computes the access set of a "scope node". + + If `scope_node` is a `MapEntry` node it will operate on the set of incoming + edges and if it is an `MapExit` node on the set of outgoing edges. The + function will then determine all access nodes that have a connection through + these edges to the scope nodes (edges that does not lead to access nodes are + ignored). + The function returns a set that contains all access nodes that were found. + It is important that this set will also contain views. + + Args: + scope_node: The scope node that should be evaluated. + state: The state in which we operate. + """ + if isinstance(scope_node, nodes.MapEntry): + get_edges = lambda node: state.in_edges(node) + other_node = lambda e: e.src + else: + get_edges = lambda node: state.out_edges(node) + other_node = lambda e: e.dst + return { + node + for node in map(other_node, get_edges(scope_node)) + if isinstance(node, nodes.AccessNode) + } + + +def is_view( + node: nodes.AccessNode, + sdfg: SDFG, +) -> bool: + """Tests if `node` points to a view or not.""" + node_desc: data.Data = node.desc(sdfg) + return isinstance(node_desc, data.View) + + def track_view( view: nodes.AccessNode, state: SDFGState, @@ -931,16 +910,21 @@ def track_view( Given the View `view`, the function will trace the view back to the original access node. + For convenience, if `view` is not a `View` but a normal data descriptor, + then the function will return the argument unmodified. Args: view: The view that should be traced. state: The state in which we operate. sdfg: The SDFG on which we operate. """ + + # Test if it is a view at all, if not return the passed node as source. + if not is_view(view, sdfg): + return view + # First determine if the view is used for reading or writing. - assert isinstance(view.desc(sdfg), data.View) curr_edge = dace.sdfg.utils.get_view_edge(state, view) - if curr_edge is None: raise RuntimeError(f"Failed to determine the direction of the view '{view}'.") if curr_edge.dst_conn == "view": @@ -955,7 +939,7 @@ def track_view( # Now trace the view back. org_view = view view = next_node(curr_edge) - while isinstance(view.desc(sdfg), data.View): + while is_view(view, sdfg): curr_edge = dace.sdfg.utils.get_view_edge(state, view) if curr_edge is None: raise RuntimeError(f"View tracing of '{org_view}' failed at note '{view}'.") diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py index 6667e0cc53..f92d84e2ad 100644 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -9,11 +9,11 @@ from dace import dtypes, properties, subsets, symbolic, transformation from dace.sdfg import SDFG, SDFGState, graph, nodes -from dace.transformation.dataflow import map_fusion_helper +from dace.transformation.dataflow import map_fusion_helper as mfh @properties.make_properties -class SerialMapFusion(map_fusion_helper.MapFusionHelper): +class SerialMapFusion(mfh.MapFusionHelper): """Specialized replacement for the map fusion transformation that is provided by DaCe. As its name is indicating this transformation is only able to handle Maps that @@ -92,6 +92,15 @@ def can_be_applied( ): return False + # Test for read-write conflicts + if self.has_read_write_dependency( + map_entry_1=map_entry_1, + map_entry_2=map_entry_2, + state=graph, + sdfg=sdfg, + ): + return False + # Two maps can be serially fused if the node decomposition exists and # at least one of the intermediate output sets is not empty. The state # of the pure outputs is irrelevant for serial map fusion. @@ -109,6 +118,102 @@ def can_be_applied( return True + def has_read_write_dependency( + self, + map_entry_1: nodes.MapEntry, + map_entry_2: nodes.MapEntry, + state: SDFGState, + sdfg: SDFG, + ) -> bool: + """Test if there is a read write dependency between the two maps. + + The function checks if the first map does not read anything from + a data descriptor, the second map writes into. + + Returns: + `True` if there is a conflict between input and outputs, `False` + if there is no conflict. + + Args: + map_entry_1: The entry node of the first map. + map_entry_2: The entry node of the second map. + state: The state on which we operate. + + Note: + The current implementation just computes the set of data that is + used as input and for output. If there is an intersection then + the function considers this as a read write conflict. + """ + map_exit_1: nodes.MapExit = state.exit_node(map_entry_1) + map_exit_2: nodes.MapExit = state.exit_node(map_entry_2) + + # Get the read and write sets of the different maps, note that Views + # are not resolved yet. + access_sets: List[Dict[str, nodes.AccessNode]] = [] + for scope_node in [map_entry_1, map_exit_1, map_entry_2, map_exit_2]: + access_sets.append({ + node.data: node + for node in mfh.get_access_set(scope_node, state) + }) + read_map_1, write_map_1, read_map_2, write_map_2 = access_sets + + # It might be possible that there are views, so we have to resolve these sets. + # We also already get the name of the data container. + resolved_sets: List[Set[str]] = [] + for unresolved_set in [read_map_1, write_map_1, read_map_2, write_map_2]: + resolved_sets.append({ + mfh.track_view(node).data if mfh.is_view(node, sdfg) else node.data + for node in unresolved_set.values() + }) + real_read_map_1, real_write_map_1, real_read_map_2, real_write_map_2 = resolved_sets + + # We now test for "structural problems", i.e. problems where the resulting + # SDFG would be invalid, all of these cases are characterized by the fact + # that both maps write to the same data. This is hard or impossible to + # handle, so we forbid all these cases. + if not real_write_map_1.isdisjoint(real_write_map_2): + return True + + # We will now test if there are no conflicts, for this we require that all + # input is distinct from the all the output. + # Must be done after the test above! + if (real_read_map_1 | real_read_map_2).isdisjoint(real_write_map_1 | real_write_map_2): + return False + + # This is the set of nodes that is used to exchange data between the two maps. + # This works, because in the partition function we ensure that such nodes are + # directly connected. + read_map_2_nodes: Set[node.AccessNode] = set(read_map_2.values()) + exchange_set: Dict[str, nodes.AccessNode] = { + name: node + for name, node in write_map_1.items() + if node in read_map_2_nodes + } + + # For simplicity we assume that the nodes used to exchange information can + # not be a View. This is a simplification. + if any(mfh.is_view(exchange_node, sdfg) for exchange_node in exchange_set.values()): + return True + + # This is the names of the node that are used as input of the first map and + # as output of the second map. We can not use the resolved here, because + # we forbid that these nodes are Views. + fused_inout_data_names: Set[str] = set(read_map_1.keys()).intersection(write_map_2.keys()) + + # Because it is hard, we do not allow Views here, because we can not resolve + # access sets (at least I can not). + if any(mfh.is_view(read_map_1[name], sdfg) for name in fused_inout_data_names): + return True + + # This is a case that can not be handled, the above code should filter this + # out, so if you are here, then the above code might have problems. + assert fused_inout_data_names.isdisjoint(exchange_set.keys()), "Constraint violation." + + # TODO: POINTWISE TEST!!! + + return False + + def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: """Performs the serial Map fusing. From 64d07fd232d1716377a6d782ee92ea19a688cadd Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 29 Aug 2024 14:28:39 +0200 Subject: [PATCH 017/115] First version of the pointwise checker in the map fusion. --- .../dataflow/map_fusion_helper.py | 5 +- .../dataflow/map_fusion_serial.py | 246 +++++++++++++++++- 2 files changed, 244 insertions(+), 7 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index 21df314ad3..2c3ecc57b7 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -885,11 +885,14 @@ def get_access_set( else: get_edges = lambda node: state.out_edges(node) other_node = lambda e: e.dst - return { + access_set: Set[nodes.AccessNode] = { node for node in map(other_node, get_edges(scope_node)) if isinstance(node, nodes.AccessNode) } + # As far as I know in a valid SDFG this should not happen. + assert len(access_set) == len({node.data for node in access_set}) + return access_set def is_view( diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py index f92d84e2ad..b1489b2394 100644 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -3,7 +3,7 @@ """Implements the serial map fusing transformation.""" import copy -from typing import Any, Dict, List, Set, Union +from typing import Any, Dict, List, Set, Union, Optional import dace from dace import dtypes, properties, subsets, symbolic, transformation @@ -159,6 +159,7 @@ def has_read_write_dependency( # It might be possible that there are views, so we have to resolve these sets. # We also already get the name of the data container. + # Note that `len(real_read_map_1) <= len(read_map_1)` holds because of Views. resolved_sets: List[Set[str]] = [] for unresolved_set in [read_map_1, write_map_1, read_map_2, write_map_2]: resolved_sets.append({ @@ -184,7 +185,7 @@ def has_read_write_dependency( # This works, because in the partition function we ensure that such nodes are # directly connected. read_map_2_nodes: Set[node.AccessNode] = set(read_map_2.values()) - exchange_set: Dict[str, nodes.AccessNode] = { + exchange_nodes: Dict[str, nodes.AccessNode] = { name: node for name, node in write_map_1.items() if node in read_map_2_nodes @@ -192,7 +193,7 @@ def has_read_write_dependency( # For simplicity we assume that the nodes used to exchange information can # not be a View. This is a simplification. - if any(mfh.is_view(exchange_node, sdfg) for exchange_node in exchange_set.values()): + if any(mfh.is_view(exchange_node, sdfg) for exchange_node in exchange_nodes.values()): return True # This is the names of the node that are used as input of the first map and @@ -206,14 +207,247 @@ def has_read_write_dependency( return True # This is a case that can not be handled, the above code should filter this - # out, so if you are here, then the above code might have problems. - assert fused_inout_data_names.isdisjoint(exchange_set.keys()), "Constraint violation." + # out, so if you are here, then the above code might have problems, + # furthermore the code below assumes it. + assert fused_inout_data_names.isdisjoint(exchange_nodes.keys()), "Constraint violation." + assert (len(fused_inout_data_names) > 0) or (len(exchange_nodes) > 0) - # TODO: POINTWISE TEST!!! + # We will now inspect the subsets. + repl_dict = self.find_parameter_remapping(map_entry_1.map, map_entry_2.map) + # First we handle the rw dependency that is given by the whole fused map. + if not self._check_read_write_dependency_fused_map( + map_entry_1=map_entry_1, + map_exit_2=map_exit_2, + inout_data_names=fused_inout_data_names, + read_map_1=read_map_1, + write_map_2=write_map_2, + repl_dict=repl_dict, + state=state, + sdfg=sdfg): + return True # There are rw dependencies. + + # Now we check the exchange nodes, i.e. the common nodes between the maps, + # are point wise. + if not self._check_read_write_dependency_exchange_nodes( + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + exchange_nodes=exchange_nodes, + repl_dict=repl_dict, + state=state, + sdfg=sdfg, + ): + return True # There are rw dependencies. + + # No read write dependency was found. return False + def _check_read_write_dependency_exchange_nodes( + self, + map_exit_1: nodes.MapExit, + map_entry_2: nodes.MapEntry, + exchange_nodes: Dict[str, nodes.AccessNode], + repl_dict: Union[Dict[str, str], None], + state: SDFGState, + sdfg: SDFG, + ) -> bool: + """Checks if there are any rw dependencies in the exchange set. + + Args: + map_exit_1: Exit node of the first (top) map; defines writes. + map_entry_2: Entry node of the second (bottom) map; defines reads. + exchange_nodes: Exchange nodes, i.e. written and read by the maps. + repl_dict: Replacement dict, for renaming the subsets of the second map. + state: The state in which we operate. + sdfg: The containing SDFG. + """ + + for exchange_node in exchange_nodes.values(): + all_subsets: List[subsets.Subset] = [] + + # The reading subsets are defined by the entry of the second map, + # thus we also have to perform some replacing of the parameters. + all_subsets.extend( + self._find_subsets( + node=exchange_node, + scope_node=map_entry_2, + state=state, + sdfg=sdfg, + repl_dict=repl_dict, + ) + ) + + # The writing subset is given by the exit of the first map. No replacing + # is needed, but the node is the same. + all_subsets.extend( + self._find_subsets( + node=exchange_node, + scope_node=map_exit_1, + state=state, + sdfg=sdfg, + repl_dict=None, + ) + ) + + if not self._test_if_subsets_are_point_wise(all_subsets): + return False + + # All subsets are point wise + return True + + + def _check_read_write_dependency_fused_map( + self, + map_entry_1: nodes.MapEntry, + map_exit_2: nodes.MapExit, + inout_data_names: Set[str], + read_map_1: Dict[str, nodes.AccessNode], + write_map_2: Dict[str, nodes.AccessNode], + repl_dict: Union[Dict[str, str], None], + state: SDFGState, + sdfg: SDFG, + ) -> bool: + """Checks the read write dependency that are given by the fused map. + + Args: + map_entry_1: The map entry node of the first (top) map. + map_exit_2: The map exit node of the second (bottom) map. + inout_data_names: Names of all data containers that are conflicting. + read_map_1: All access nodes from which the first map reads (`node.data -> node`). + write_map_2: All access nodes to which the second map writes (`node.data -> node`). + repl_dict: Replacement dict for renaming the second maps iteration parameters. + state: The state in which we operate. + sdfg: The containing SDFG. + """ + for inout_data_name in inout_data_names: + all_subsets: List[subsets.Subset] = [] + + # The subsets that define reading are given by the first map's entry node + all_subsets.extend( + self._find_subsets( + node=read_map_1[inout_data_name], + scope_node=map_entry_1, + state=state, + sdfg=sdfg, + repl_dict=None, + ) + ) + + # While the subsets defining writing are given by the second map's exit + # node, there we also have to apply renaming. + all_subsets.extend( + self._find_subsets( + node=write_map_2[inout_data_name], + scope_node=map_exit_2, + state=state, + sdfg=sdfg, + repl_dict=repl_dict, + ) + ) + + # Now we can test if these subsets are point wise + if not self._test_if_subsets_are_point_wise(all_subsets): + return False + + # All subsets are point wise + return True + + + + def _test_if_subsets_are_point_wise( + self, + subsets_to_check: List[subsets.Subset] + ) -> bool: + """Point wise means that they are all the same. + + If a series of subsets are point wise it means that all Memlets, access + the same data. This is an important property because the whole map fusion + is build upon this. + + Args: + subsets_to_check: The list of subsets that should be checked. + """ + assert len(subsets_to_check) > 1 + + # We will check everything against the master subset. + master_subset = subsets_to_check[0] + for ssidx in range(1, len(subsets_to_check)): + subset = subsets_to_check[ssidx] + if isinstance(subset, subsets.Indices): + subset = subsets.Range.from_indices(subset) + # Do we also need the reverse? See below why. + if any(r != (0, 0, 1) for r in test in subset.offset_new(master_subset,negative=True)): + return False + else: + # The original code used `Range.offset` here, but that one had trouble + # for `r1 = 'j, 0:10'` and `r2 = 'j, 0`. The solution would be to test + # symmetrically, i.e. `r1 - r2` and `r2 - r1`. However, if we would + # have `r2_1 = 'j, 0:10'` it consider it as failing, which is not + # what we want. Thus we will use symmetric cover. + if not master_subset.covers(subset): + return False + if not subset.covers(master_subset): + return False + + # All subsets are equal to the master subset, thus they are equal to each other. + # This means that the data accesses, described by this transformation is + # point wise + return True + + + def _find_subsets( + self, + node: nodes.AccessNode, + scope_node: Union[nodes.MapExit, nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, + repl_dict: Optional[Dict[str, str]], + ) -> List[subsets.Subset]: + """Finds all subsets involving node `node`. + + Args: + node: The access node that should be examined. + scope_node: We are only interested in data that flows through this node. + state: The state in which we operate. + sdfg: The SDFG object. + """ + + # Is the node used for reading or for writing. + # This influences how we have to proceed. + if isinstance(scope_node, nodes.MapEntry): + used_for_reading = True + edges_to_inspect = state.in_edges(scope_node) + test_edge = lambda e: (e.src == node) + get_subset = lambda e: e.data.src_subset + else: + used_for_reading = False + edges_to_inspect = state.out_edges(scope_node) + test_edge = lambda e: e.dst == node + get_subset = lambda e: e.data.dst_subset + + found_subsets: List[subsets.Subset] = [] + for edge in edges_to_inspect: + if not test_edge(edge): + continue + if used_for_reading: + consumer_or_producer = mfh.find_downstream_consumers(state, begin=edge) + else: + consumer_or_producer = mfh.find_upstream_producers(state, begin=edge) + found_subsets.extend(get_subset(e) for _, e in consumer_or_producer) + assert len(found_subsets) > 0, f"Could not find any subsets." + assert not any(subset is None for subset in found_subsets) + + # The deepcopy is needed if we would do renaming. + found_subsets = copy.deepcopy(found_subsets) + if repl_dict: + for subset in found_subsets: + # Replace happens in place + symbolic.safe_replace(repl_dict, subset.replace) + + return found_subsets + + def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: """Performs the serial Map fusing. From 33a0edf55c2fec768d30960f654152b3f4c0b1ce Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 29 Aug 2024 14:45:00 +0200 Subject: [PATCH 018/115] Updated some test cases. --- tests/transformations/mapfusion_data_races_test.py | 4 +++- tests/transformations/mapfusion_test.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/transformations/mapfusion_data_races_test.py b/tests/transformations/mapfusion_data_races_test.py index 13a71370e1..0466a32551 100644 --- a/tests/transformations/mapfusion_data_races_test.py +++ b/tests/transformations/mapfusion_data_races_test.py @@ -87,6 +87,7 @@ def test_rw_data_race_3_mf(): def test_rw_data_race_4_mf(): # It is technically possible to fuse it, because there is only a point wise dependency. + # However, it is very hard to detect and handle correct. sdfg = rw_data_race_4.to_sdfg(simplify=True) sdfg.apply_transformations_repeated(MapFusion) map_entry_nodes = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, nodes.MapEntry)] @@ -95,8 +96,9 @@ def test_rw_data_race_4_mf(): if __name__ == "__main__": test_rw_data_race() - test_rw_data_race_2_mf() test_rw_data_race_2_sgf() + test_rw_data_race_2_mf() test_rw_data_race_3_sgf() test_rw_data_race_3_mf() test_rw_data_race_4_mf() + print("SUCCESS") diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 8055afee17..c9d77608c8 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -388,3 +388,4 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 test_fusion_with_empty_memlet() test_fusion_with_nested_sdfg_0() test_fusion_with_nested_sdfg_1() + print("SUCCESS") From dbb989eea499c76219da754a3c8b878c3aefa3d0 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 29 Aug 2024 15:23:02 +0200 Subject: [PATCH 019/115] Changed how the `_find_subsets()` works in the dependency tests. Before it was going to look for the memlet of the consumer or producer. However, one should actually only look at the memlets that are adjacent to the scope node. At least this is how the original worked. I noticed this because of the `buffer_tiling_test.py::test_basic()` test. I was not yet focused on maps that were nested and not multidimensional. It seems that the transformation has some problems there. --- dace/transformation/dataflow/map_fusion_serial.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py index b1489b2394..bdba32885b 100644 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -406,6 +406,10 @@ def _find_subsets( ) -> List[subsets.Subset]: """Finds all subsets involving node `node`. + The function will not start a search for all consumer/producers. + Instead it will locate the edges which is immediately inside the + map scope. + Args: node: The access node that should be examined. scope_node: We are only interested in data that flows through this node. @@ -423,7 +427,7 @@ def _find_subsets( else: used_for_reading = False edges_to_inspect = state.out_edges(scope_node) - test_edge = lambda e: e.dst == node + test_edge = lambda e: (e.dst == node) get_subset = lambda e: e.data.dst_subset found_subsets: List[subsets.Subset] = [] @@ -431,10 +435,10 @@ def _find_subsets( if not test_edge(edge): continue if used_for_reading: - consumer_or_producer = mfh.find_downstream_consumers(state, begin=edge) + inner_edges = state.out_edges_by_connector(scope_node, "OUT_" + edge.dst_conn[3:]) else: - consumer_or_producer = mfh.find_upstream_producers(state, begin=edge) - found_subsets.extend(get_subset(e) for _, e in consumer_or_producer) + inner_edges = state.in_edges_by_connector(scope_node, "IN_" + edge.src_conn[4:]) + found_subsets.extend(get_subset(e) for e in inner_edges) assert len(found_subsets) > 0, f"Could not find any subsets." assert not any(subset is None for subset in found_subsets) From e142881d1fb6a25aa6a9c23a517cfa6c9ed5bc4d Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 30 Aug 2024 11:09:43 +0200 Subject: [PATCH 020/115] Updated the map fusion's partitioning function. Whet it now cheks for covering (i.e. if the information to exchange is enough) it will now no longer decend into the maps, but only inspect the first outgoing/incomming edges of the map entrie and exit. I noticed that the other way was to restrictive, especially for map tiling. --- .../dataflow/map_fusion_helper.py | 203 ++++++++---------- 1 file changed, 90 insertions(+), 113 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index 2c3ecc57b7..f38bfa0f7a 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -494,6 +494,7 @@ def partition_first_outputs( first_map=map_exit_1.map, second_map=map_entry_2.map, ) + assert repl_dict is not None # Set of intermediate nodes that we have already processed. processed_inter_nodes: Set[nodes.Node] = set() @@ -508,8 +509,15 @@ def partition_first_outputs( return None processed_inter_nodes.add(intermediate_node) + # The intermediate can only have one incoming degree. It might be possible + # to handle multiple incoming edges, if they all come from the top map. + # However, the resulting SDFG might be invalid. + if state.in_degree(intermediate_node) != 1: + return None + # Now let's look at all nodes that are downstream of the intermediate node. # This, among other things, will tell us, how we have to handle this node. + # NOTE: The traversal will stop at the second map. downstream_nodes = all_nodes_between( graph=state, begin=intermediate_node, @@ -528,6 +536,12 @@ def partition_first_outputs( # cases, as handling them is essentially rerouting an edge, whereas # handling intermediate nodes is much more complicated. + # If `downstream_nodes` is empty, this means that the second map entry + # was found immediately, we only allow the case that there is one + # connecting Memlet. + if (len(downstream_nodes) == 0) and state.out_degree(intermediate_node) != 1: + return None + # For us an intermediate node must always be an access node, because # everything else we do not know how to handle. It is important that # we do not test for non transient data here, because they can be @@ -557,125 +571,88 @@ def partition_first_outputs( # of the first map exit, but there is only one edge leaving the exit. # It is complicate to handle this, so for now we ignore it. # TODO(phimuell): Handle this case properly. - # The main reason why we forbid this is because it becomes a bit tricky - # to figuring out the size of the intermediate. - inner_collector_edges = list(state.in_edges_by_connector(intermediate_node, "IN_" + out_edge.src_conn[3:])) - if len(inner_collector_edges) > 1: + # To handle this we need to associate a consumer edge (the outgoing edges + # of the second map) with exactly one producer. + producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list(state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:])) + if len(producer_edges) > 1: return None - # We now look for all producers inside the top map. We need this information - # later to determine if the point wise output of the top map is enough - # to serve as (point wise) input for the second map. - # In addition we will check if there is a WCR edge is present, which we - # can not handle. - producers = find_upstream_producers(state, out_edge) - - # More than one producer is not supported. - if len(producers) != 1: - return None - producer_node, producer_edge = next(iter(producers)) - - # If the producer is a view, then we give up. It is possible to handle, - # but also very complicated. - if isinstance(producer_node, nodes.AccessNode) and isinstance(producer_node.desc(sdfg), data.View): - return None - - # We do not allow that the edge has WCR. - if producer_edge.data.wcr is not None: + # Now check the constraints we have on the producers. + # - The source of the producer can not be a view (we do not handle this) + # - The edge shall also not be a reduction edge. + # - Defined location to where they write. + # Furthermore, we will also extract the subsets, i.e. the location they + # modify inside the intermediate array. + producer_subsets: List[subsets.Subset] = [] + for producer_edge in producer_edges: + if isinstance(producer_edge.src, nodes.AccessNode) and is_view(producer_edge.src, sdfg): + return None + if producer_edge.data.wcr is not None: + return None + if producer_edge.data.dst_subset is None: + return None + producer_subsets.append(producer_edge.data.dst_subset) + + # Now we determine the consumer of nodes. For this we are using the edges + # leaves the second map entry. We could find the final consumers (e.g. + # Tasklets), however, this might make problems as they depend also on + # symbols defined by nested maps. However, we are not interested in edges, + # but actually what they read, i.e. their source subset. + # In any case there can be at most one connection between the intermediate + # and the second map entry. + found_second_map = False + consumer_subsets: List[subsets.Subset] = [] + for intermediate_node_out_edges in state.out_edges(intermediate_node): + # Check if we have not reached the second map entry. + # This happens if the intermediate node is a shared node. + # However, we are only allowed to find the second map once. + if intermediate_node_out_edges.dst is not map_entry_2: + continue + if found_second_map: + # TODO(phimuell): Lift this restriction. + return None + found_second_map = True + assert intermediate_node_out_edges.dst_conn.startswith("IN_") + consumer_subsets.extend( + e.data.src_subset + for e in state.out_edges_by_connector(map_entry_2, "OUT_" + intermediate_node_out_edges.dst_conn[3:]) + ) + # The subsets are not set correctly, so we give up. + if any(consumer_subset is None for consumer_subset in consumer_subsets): return None - - if len(downstream_nodes) == 0: - # TODO(phimuell): Refactor this if such that the loop is only there once. - - # There is nothing between intermediate node and the entry of the - # second map, thus the edge belongs either in `\mathbb{S}` or - # `\mathbb{E}`. - - # If the intermediate access node as more than one outgoing edge - # it means (because of `downstream_nodes`) that it has multiple - # connections to the second map. We do not allow this. - # TODO(phimuell): Handle this case. - if state.out_degree(intermediate_node) != 1: + assert len(consumer_subsets) != 0 + + # Furthermore, the consumer still uses the original symbols of the + # second map, so we must rename them. + if repl_dict: + consumer_subsets = copy.deepcopy(consumer_subsets) + for consumer_subset in consumer_subsets: + symbolic.safe_replace(mapping=repl_dict, replace_callback=consumer_subset.replace) + + # Now we are checking if a single iteration of the first (top) map + # can satisfy all data requirements of the second (bottom) map. + # For this we look if the producer covers the consumer. A consumer must + # be covered by exactly one producer. + for consumer_subset in consumer_subsets: + nb_coverings = sum(producer_subset.covers(consumer_subset) for producer_subset in producer_subsets) + if nb_coverings != 1: return None - # We can fuse the maps only if the producer, i.e. the top map, - # represented by `producer_edge`, is gives us enough information. - producer_subset: subsets.Range = producer_edge.data.dst_subset - consumers = find_downstream_consumers(state=state, begin=intermediate_node) - - for consumer_node, consumer_edge in consumers: - if consumer_node is map_entry_2: # Dynamic map range. - return None - if isinstance(consumer_node, nodes.LibraryNode): - # TODO(phimuell): Allow some library nodes. - return None - - # Tests if consumer consume less or equal the amount we generate. - # If the source of the consumer is not set, we can not check what it reads. - if consumer_edge.data.src_subset is None: - return None - # For that we have to perform a replace operation of the consumer . - mod_consumer_subset = copy.deepcopy(consumer_edge.data.src_subset) - symbolic.safe_replace(mapping=repl_dict, replace_callback=mod_consumer_subset.replace) - - if not producer_subset.covers(mod_consumer_subset): - return None - - # Note that "remove" has a special meaning here, regardless of the - # output of the check function, from within the second map we remove - # the intermediate, it has more the meaning of "do we need to - # reconstruct it after the second map again?" - if self.is_shared_data(intermediate_node, sdfg): - shared_outputs.add(out_edge) - else: - exclusive_outputs.add(out_edge) - continue - - else: - # There is not only a single connection from the intermediate node to - # the second map, but the intermediate has more connections, thus - # the node might belong to the shared output. Of the many different - # possibilities, we only consider a single case: - # - The intermediate has a single connection to the second map, that - # fulfills the restriction outlined above. - # - All other connections have no connection to the second map. - - # We can fuse the maps only if the producer, i.e. the top map, - # represented by `producer_edge`, is gives us enough information. - producer_subset: subsets.Range = producer_edge.data.dst_subset - consumers = find_downstream_consumers(state=state, begin=intermediate_node) - - found_second_entry = False - for edge in state.out_edges(intermediate_node): - if edge.dst is map_entry_2: - if found_second_entry: # The second map was found again. - return None - found_second_entry = True - consumers = find_downstream_consumers(state=state, begin=edge) - for consumer_node, consumer_edge in consumers: - if consumer_node is map_entry_2: # Dynamic map range. - return None - if isinstance(consumer_node, nodes.LibraryNode): - # TODO(phimuell): Allow some library nodes. - return None - - # Tests if consumer consume less or equal the amount we generate. - # If the source of the consumer is not set, we can not check what it reads. - if consumer_edge.data.src_subset is None: - return None - # For that we have to perform a replace operation of the consumer . - mod_consumer_subset = copy.deepcopy(consumer_edge.data.src_subset) - symbolic.safe_replace(mapping=repl_dict, replace_callback=mod_consumer_subset.replace) - if not producer_subset.covers(mod_consumer_subset): - return None - else: - # Ensure that there is no path that leads to the second map. - after_intermdiate_node = all_nodes_between(graph=state, begin=edge.dst, end=map_entry_2) - if after_intermdiate_node is not None: - return None - # If we are here, then we know that the node is a shared output + # After we have ensured coverage, we have to decide if the intermediate + # node can be removed (`\mathbb{E}`) or has to be restored (`\mathbb{S}`). + # Note that "removed" here means that it is reconstructed by a new + # output of the second map. + if len(downstream_nodes) != 0: + # The intermediate node is connected to more node inside this state, + # that are not inside the map, so we must keep it alive. shared_outputs.add(out_edge) - continue + elif self.is_shared_data(intermediate_node, sdfg): + # The intermediate data is refered to somewhere else. + # So it must be passed. + shared_outputs.add(out_edge) + else: + # The intermediate can be removed, as it is not used anywhere else. + exclusive_outputs.add(out_edge) assert exclusive_outputs or shared_outputs or pure_outputs assert len(processed_inter_nodes) == sum(len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs]) From ff018f4108d4c9d3e1a07c8262a7cc85384fde31 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 30 Aug 2024 14:00:49 +0200 Subject: [PATCH 021/115] The shared data cache can not be dumped. Otherwise we can end up in recursion. --- dace/transformation/dataflow/map_fusion_helper.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index f38bfa0f7a..fcc810aad3 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -47,6 +47,8 @@ class MapFusionHelper(transformation.SingleStateTransformation): value_type=set, #[str] default=None, allow_none=True, + optional=True, # Do not serialize. + optional_condition=lambda _: False, desc="Maps SDFGs to the set of data that can not be removed. " "The variable acts as a cache, and is managed by 'is_shared_data()'.", ) From ec6339a3c8d45c9dacf80dea1c58a12b1d13708e Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 30 Aug 2024 14:02:05 +0200 Subject: [PATCH 022/115] Reworked how the serial fusion adjuists Memlets. Before it was replacing the elimated variables by zero. Which actually worked pretty good, but I have now changed that such that `offset()` is used. I am not sure why I used `replace` in the first place, but I think that there was an issue. However, I am not sure. --- .../dataflow/map_fusion_serial.py | 100 ++++++++---------- 1 file changed, 44 insertions(+), 56 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py index bdba32885b..6c2eacd31b 100644 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -572,15 +572,6 @@ def handle_intermediate_set( after this function has run the state is (most likely) invalid. """ - # Essentially this function removes the AccessNode between the two maps. - # However, we still need some temporary memory that we can use, which is - # just much smaller, i.e. a scalar. But all Memlets inside the second map - # assumes that the intermediate memory has the bigger shape. - # To fix that we will create this replacement dict that will replace all - # occurrences of the iteration variables of the second map with zero. - # Note that this is still not enough as the dimensionality might be different. - memlet_repl: Dict[str, int] = {str(param): 0 for param in map_entry_2.map.params} - # Now we will iterate over all intermediate edges and process them. # If not stated otherwise the comments assume that we run in exclusive mode. for out_edge in intermediate_outputs: @@ -649,52 +640,43 @@ def handle_intermediate_set( # New we will reroute the output Memlet, thus it will no longer pass # through the Map exit but through the newly created intermediate. - # we will delete the previous edge later. - pre_exit_memlet: dace.Memlet = pre_exit_edge.data - new_pre_exit_memlet = copy.deepcopy(pre_exit_memlet) - - # We might operate on a different array, but the check below, ensures - # that we do not change the direction of the Memlet. - assert pre_exit_memlet.data == inter_name - new_pre_exit_memlet.data = new_inter_name - - # Now we have to modify the subset of the Memlet. - # Before the subset of the Memlet was dependent on the Map variables, - # however, this is no longer the case, as we removed them. This change - # has to be reflected in the Memlet. - # NOTE: Assert above ensures that the below is correct. - new_pre_exit_memlet.replace(memlet_repl) - if is_scalar: - new_pre_exit_memlet.subset = "0" - new_pre_exit_memlet.other_subset = None - else: - new_pre_exit_memlet.subset.pop(squeezed_dims) - - # Now we create the new edge between the producer and the new output - # (the new intermediate node). We will remove the old edge further down. + # NOTE: We will delete the previous edge later. new_pre_exit_edge = state.add_edge( pre_exit_edge.src, pre_exit_edge.src_conn, new_inter_node, None, - new_pre_exit_memlet, + sdfg.make_array_memlet(new_inter_name), ) - # We just have handled the last Memlet, but we must actually handle the - # whole producer side, i.e. the scope of the top Map. - for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children(): + # Get the subset that defined into which part of the old intermediate + # the old output edge wrote to. We need that to adjust the producer + # Memlets, since they now write into the new (smaller) intermediate. + assert pre_exit_edge.data.data == inter_name + assert pre_exit_edge.data.dst_subset is not None + old_pre_exit_edge_subset = pre_exit_edge.data.dst_subset + + # We now handle the MemletTree defined by this edge. + # The newly created edge, only handled the last collection step. + for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children(include_self=True): producer_edge = producer_tree.edge - # Ensure the correctness of the rerouting below. + # Exclude the edge we just have created. + if producer_edge is new_pre_exit_edge: + continue + + # Associate the (already existing) Memlet with the new data. # TODO(phimuell): Improve the code below to remove the check. assert producer_edge.data.data == inter_name - - # Will not change the direction, because of test above! producer_edge.data.data = new_inter_name - producer_edge.data.replace(memlet_repl) + if is_scalar: producer_edge.data.dst_subset = "0" elif producer_edge.data.dst_subset is not None: + # Since we now write into a smaller memory patch, we must + # compensate for that. We do this by substracting where the write + # originally had begun. + producer_edge.data.dst_subset.offset(old_pre_exit_edge_subset, negative=True) producer_edge.data.dst_subset.pop(squeezed_dims) # Now after we have handled the input of the new intermediate node, @@ -723,17 +705,18 @@ def handle_intermediate_set( for inner_edge in state.out_edges_by_connector(map_entry_2, out_conn_name): assert inner_edge.data.data == inter_name # DIRECTION!! - # The create the first Memlet to transmit information, within - # the second map, we do this again by copying and modifying - # the original Memlet. - # NOTE: Test above is important to ensure the direction of the - # Memlet and the correctness of the code below. + # As for the producer side, we now read from a smaller array, + # So we must offset them, we use the original edge for this. + assert inner_edge.data.src_subset is not None, f"{inner_edge} | {inner_edge.data} | {inner_edge.data.src_subset} | {inner_edge.data.dst_subset} " + inner_edge_correction_offset = inner_edge.data.src_subset + + # Now we create a new connection that instead reads from the new + # intermediate, instead of the old one. For this we use the + # old Memlet as template. However it is not fully initialized. new_inner_memlet = copy.deepcopy(inner_edge.data) - new_inner_memlet.replace(memlet_repl) - new_inner_memlet.data = new_inter_name # Because of the assert above, this will not change the direction. + new_inner_memlet.data = new_inter_name - # Now remove the old edge, that started the second map entry. - # Also add the new edge that started at the new intermediate. + # Now we replace the edge from the SDFG. state.remove_edge(inner_edge) new_inner_edge = state.add_edge( new_inter_node, @@ -743,22 +726,27 @@ def handle_intermediate_set( new_inner_memlet, ) - # Now we do subset modification to ensure that nothing failed. + # Now modifying the Memlet, we do it after the insertion to make + # sure that the Memlet was properly initialized. if is_scalar: - new_inner_memlet.src_subset = "0" + new_inner_memlet.subset = "0" elif new_inner_memlet.src_subset is not None: + new_inner_memlet.src_subset.offset(inner_edge_correction_offset, negative=True) new_inner_memlet.src_subset.pop(squeezed_dims) - # Now clean the Memlets of that tree to use the new intermediate node. - for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(): + # Now we have to make sure that all consumers are properly updated. + for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(include_self=True): + if consumer_tree.edge is new_inner_edge: + continue + assert consumer_tree.edge.data.data == inter_name + consumer_edge = consumer_tree.edge - assert consumer_edge.data.data == inter_name consumer_edge.data.data = new_inter_name - consumer_edge.data.replace(memlet_repl) if is_scalar: consumer_edge.data.src_subset = "0" elif consumer_edge.data.subset is not None: - consumer_edge.data.subset.pop(squeezed_dims) + consumer_edge.data.src_subset.offset(inner_edge_correction_offset, negative=True) + consumer_edge.data.src_subset.pop(squeezed_dims) # The edge that leaves the second map entry was already deleted. # We will now delete the edges that brought the data. From 9267ea97126c3fc627a0d8e81dc30866398af4bf Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 30 Aug 2024 14:09:56 +0200 Subject: [PATCH 023/115] Buffer tiling now finally works. --- dace/transformation/dataflow/buffer_tiling.py | 8 +++++++- tests/buffer_tiling_test.py | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/dace/transformation/dataflow/buffer_tiling.py b/dace/transformation/dataflow/buffer_tiling.py index a418e167d8..b7d7a5607b 100644 --- a/dace/transformation/dataflow/buffer_tiling.py +++ b/dace/transformation/dataflow/buffer_tiling.py @@ -98,7 +98,13 @@ def apply(self, graph, sdfg): # Fuse maps some_buffer = next(iter(buffers)) # some dummy to pass to MapFusion.apply_to() - MapFusion.apply_to(sdfg, first_map_exit=tile_map1_exit, array=some_buffer, second_map_entry=tile_map2_entry) + MapFusion.apply_to( + sdfg, + map_exit1=tile_map1_exit, + access_node=some_buffer, + map_entry2=tile_map2_entry, + verify=True, + ) # Optimize the simple cases map1_entry.range.ranges = [ diff --git a/tests/buffer_tiling_test.py b/tests/buffer_tiling_test.py index 52477dcc72..03635a7721 100644 --- a/tests/buffer_tiling_test.py +++ b/tests/buffer_tiling_test.py @@ -78,6 +78,7 @@ def _semantic_eq(tile_sizes, program): count = sdfg.apply_transformations(BufferTiling, options={'tile_sizes': tile_sizes}) assert count > 0 + sdfg.validate() sdfg(w3=w3, w5=w5, A=A, B=B2, I=A.shape[0], J=A.shape[1]) assert np.allclose(B1, B2) From fc2db8a5a5ccc27132988418c54cba8ab05cdb1e Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 30 Aug 2024 14:11:12 +0200 Subject: [PATCH 024/115] The Mapreduce now also works. --- dace/transformation/dataflow/mapreduce.py | 5 +++-- tests/transformations/mapfusion_test.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/dace/transformation/dataflow/mapreduce.py b/dace/transformation/dataflow/mapreduce.py index 0eef39c3cb..11445344ec 100644 --- a/dace/transformation/dataflow/mapreduce.py +++ b/dace/transformation/dataflow/mapreduce.py @@ -216,8 +216,9 @@ def apply(self, graph: SDFGState, sdfg: SDFG): map_entry, _ = map_collapse.apply(graph, sdfg) map_fusion = MapFusion() + # What is with the array? map_fusion.setup_match(sdfg, self.cfg_id, self.state_id, { - MapFusion.first_map_exit: graph.node_id(self.tmap_exit), - MapFusion.second_map_entry: graph.node_id(map_entry), + MapFusion.map_exit1: graph.node_id(self.tmap_exit), + MapFusion.map_entry2: graph.node_id(map_entry), }, 0) map_fusion.apply(graph, sdfg) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index c9d77608c8..920830248e 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -4,6 +4,7 @@ import numpy as np import os import dace +import copy from dace import SDFG, SDFGState from dace.sdfg import nodes @@ -31,6 +32,7 @@ def apply_fusion( number of maps. """ num_maps_before = count_node(sdfg, nodes.MapEntry) + org_sdfg = copy.deepcopy(sdfg) sdfg.apply_transformations_repeated(MapFusion, validate=True, validate_all=True) num_maps_after = count_node(sdfg, nodes.MapEntry) @@ -378,11 +380,11 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 if __name__ == '__main__': + test_fusion_with_transient() test_fusion_rename() test_fusion_simple() test_multiple_fusions() test_fusion_chain() - test_fusion_with_transient() test_fusion_with_transient_scalar() test_fusion_with_inverted_indices() test_fusion_with_empty_memlet() From 4d9f11dac94ec94c2d5024d83eca5b1651f9c999 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 30 Aug 2024 14:41:28 +0200 Subject: [PATCH 025/115] Added a test to the map fusion stuff that ensures that the shared block is taken. --- tests/transformations/mapfusion_test.py | 26 +++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 920830248e..c5e8e8f240 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -142,6 +142,14 @@ def fusion_chain(A: dace.float32[10, 20], B: dace.float32[10, 20]): tmp2 = tmp1 * 4 B[:] = tmp2 + 5 + +@dace.program +def fusion_shared_output(A: dace.float32[10, 20], B: dace.float32[10, 20], C: dace.float32[10, 20]): + tmp = A + 3 + B[:] = tmp * 4 + C[:] = tmp / 6 + + def test_fusion_simple(): sdfg = fusion_simple.to_sdfg() sdfg = apply_fusion(sdfg, final_maps=1) @@ -155,6 +163,7 @@ def test_fusion_simple(): print('Difference:', diff) assert diff <= 1e-3 + def test_fusion_rename(): sdfg = fusion_rename.to_sdfg() sdfg = apply_fusion(sdfg, final_maps=1) @@ -169,6 +178,22 @@ def test_fusion_rename(): assert diff <= 1e-3 +def test_fusion_shared(): + sdfg = fusion_shared_output.to_sdfg() + sdfg = apply_fusion(sdfg) + + A = np.random.rand(10, 20).astype(np.float32) + B = np.random.rand(10, 20).astype(np.float32) + C = np.random.rand(10, 20).astype(np.float32) + + B_res = (A + 3) * 4 + C_res = (A + 3) / 6 + sdfg(A=A, B=B, C=C) + + assert np.allclose(B_res, B) + assert np.allclose(C_res, C) + + def test_multiple_fusions(): sdfg = multiple_fusions.to_sdfg() @@ -380,6 +405,7 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 if __name__ == '__main__': + test_fusion_shared() test_fusion_with_transient() test_fusion_rename() test_fusion_simple() From 2b91465b29b85a80cb30c24f5c2fd7dbf3de0cac Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 30 Aug 2024 14:48:44 +0200 Subject: [PATCH 026/115] Added a test for the indirect accesses case. --- tests/transformations/mapfusion_test.py | 56 ++++++++++++++++++------- 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index c5e8e8f240..2c4e93a8c1 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -143,6 +143,23 @@ def fusion_chain(A: dace.float32[10, 20], B: dace.float32[10, 20]): B[:] = tmp2 + 5 +@dace.program +def fusion_with_transient(A: dace.float64[2, 20]): + res = np.ndarray([2, 20], dace.float64) + for i in dace.map[0:20]: + for j in dace.map[0:2]: + with dace.tasklet: + a << A[j, i] + t >> res[j, i] + t = a * a + for i in dace.map[0:20]: + for j in dace.map[0:2]: + with dace.tasklet: + t << res[j, i] + o >> A[j, i] + o = t * 2 + + @dace.program def fusion_shared_output(A: dace.float32[10, 20], B: dace.float32[10, 20], C: dace.float32[10, 20]): tmp = A + 3 @@ -150,6 +167,12 @@ def fusion_shared_output(A: dace.float32[10, 20], B: dace.float32[10, 20], C: da C[:] = tmp / 6 +@dace.program +def fusion_indirect_access(A: dace.float32[100], B: dace.float32[100], idx: dace.int32[30], out: dace.float32[30]): + tmp = (A + B * 2) + 3 + out[:] = tmp[idx] + + def test_fusion_simple(): sdfg = fusion_simple.to_sdfg() sdfg = apply_fusion(sdfg, final_maps=1) @@ -194,6 +217,21 @@ def test_fusion_shared(): assert np.allclose(C_res, C) +def test_indirect_accesses(): + sdfg = fusion_indirect_access.to_sdfg() + sdfg = apply_fusion(sdfg, final_maps=2) + + A = np.random.rand(100).astype(np.float32) + B = np.random.rand(100).astype(np.float32) + idx = ((np.random.rand(30) * 100) % 100).astype(np.int32) + out = np.zeros(shape=30, dtype=np.float32) + + res = ((A + B * 2) + 3)[idx] + sdfg(A=A, B=B, idx=idx, out=out) + + assert np.allclose(res, out) + + def test_multiple_fusions(): sdfg = multiple_fusions.to_sdfg() @@ -228,22 +266,6 @@ def test_fusion_chain(): assert diff <= 1e-4 -@dace.program -def fusion_with_transient(A: dace.float64[2, 20]): - res = np.ndarray([2, 20], dace.float64) - for i in dace.map[0:20]: - for j in dace.map[0:2]: - with dace.tasklet: - a << A[j, i] - t >> res[j, i] - t = a * a - for i in dace.map[0:20]: - for j in dace.map[0:2]: - with dace.tasklet: - t << res[j, i] - o >> A[j, i] - o = t * 2 - def test_fusion_with_transient(): A = np.random.rand(2, 20) @@ -405,6 +427,7 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 if __name__ == '__main__': + test_indirect_accesses() test_fusion_shared() test_fusion_with_transient() test_fusion_rename() @@ -417,3 +440,4 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 test_fusion_with_nested_sdfg_0() test_fusion_with_nested_sdfg_1() print("SUCCESS") + From 73f4415453c14d46926fa26b7e52b24e5a62526a Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 30 Aug 2024 14:52:26 +0200 Subject: [PATCH 027/115] Updated the heat 3d test. It now ensures that the fusion is now done. --- tests/npbench/polybench/heat_3d_test.py | 12 ++++++++++++ tests/npbench/polybench/jacobi_2d_test.py | 1 + 2 files changed, 13 insertions(+) diff --git a/tests/npbench/polybench/heat_3d_test.py b/tests/npbench/polybench/heat_3d_test.py index e058914fd3..75ad902c4b 100644 --- a/tests/npbench/polybench/heat_3d_test.py +++ b/tests/npbench/polybench/heat_3d_test.py @@ -64,10 +64,22 @@ def run_heat_3d(device_type: dace.dtypes.DeviceType): A_ref = np.copy(A) B_ref = np.copy(B) + def count_maps(sdfg: dc.SDFG) -> int: + nb_maps = 0 + for _, state in sdfg.all_nodes_recursive(): + node: dc.SDFGState + for node in state.nodes(): + if isinstance(node, dc.sdfg.nodes.MapEntry): + nb_maps += 1 + return nb_maps + if device_type in {dace.dtypes.DeviceType.CPU, dace.dtypes.DeviceType.GPU}: # Parse the SDFG and apply auto-opt sdfg = heat_3d_kernel.to_sdfg() + initial_maps = count_maps(sdfg) sdfg = auto_optimize(sdfg, device_type) + after_maps = count_maps(sdfg) + assert after_maps < initial_maps, f"Expected less maps, initially {initial_maps} many maps, but after optimization {after_maps}" sdfg(TSTEPS, A, B, N=N) elif device_type == dace.dtypes.DeviceType.FPGA: # Parse SDFG and apply FPGA friendly optimization diff --git a/tests/npbench/polybench/jacobi_2d_test.py b/tests/npbench/polybench/jacobi_2d_test.py index bc2d5a4f2b..61982c427f 100644 --- a/tests/npbench/polybench/jacobi_2d_test.py +++ b/tests/npbench/polybench/jacobi_2d_test.py @@ -47,6 +47,7 @@ def run_jacobi_2d(device_type: dace.dtypes.DeviceType): # Parse the SDFG and apply autopot sdfg = kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) + sdfg(A=A, B=B, TSTEPS=TSTEPS, N=N) elif device_type == dace.dtypes.DeviceType.FPGA: From 94ecd192ff79f7223cce8c4ea77196112398b371 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 30 Aug 2024 15:43:54 +0200 Subject: [PATCH 028/115] Fixed an error in the parallel map fusion. --- dace/transformation/dataflow/map_fusion_parallel.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dace/transformation/dataflow/map_fusion_parallel.py b/dace/transformation/dataflow/map_fusion_parallel.py index eae550c5e3..414a565d7e 100644 --- a/dace/transformation/dataflow/map_fusion_parallel.py +++ b/dace/transformation/dataflow/map_fusion_parallel.py @@ -103,7 +103,6 @@ def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: Map to the corresponding nodes of the first Map. Afterwards the nodes of the second Map are removed. """ - assert self.map_parameter_compatible(self.map_entry1.map, self.map_entry2.map, graph, sdfg) map_entry_1: nodes.MapEntry = self.map_entry1 map_exit_1: nodes.MapExit = graph.exit_node(map_entry_1) From 3f3f8a3c8432dff1f2f0b6dfe36690731f46d75a Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 30 Aug 2024 15:44:24 +0200 Subject: [PATCH 029/115] Added a test for the parallel map fusion transformations. --- tests/transformations/mapfusion_test.py | 78 ++++++++++++++++++++++++- 1 file changed, 76 insertions(+), 2 deletions(-) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 2c4e93a8c1..4236cd6a46 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -8,7 +8,7 @@ from dace import SDFG, SDFGState from dace.sdfg import nodes -from dace.transformation.dataflow import MapFusion +from dace.transformation.dataflow import SerialMapFusion, ParallelMapFusion def count_node(sdfg: SDFG, node_type): @@ -33,7 +33,7 @@ def apply_fusion( """ num_maps_before = count_node(sdfg, nodes.MapEntry) org_sdfg = copy.deepcopy(sdfg) - sdfg.apply_transformations_repeated(MapFusion, validate=True, validate_all=True) + sdfg.apply_transformations_repeated(SerialMapFusion, validate=True, validate_all=True) num_maps_after = count_node(sdfg, nodes.MapEntry) has_processed = False @@ -426,6 +426,79 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 assert isinstance(src, dace.nodes.AccessNode) +def test_parallel_fusion_simple(): + N1, N2 = 10, 20 + + def _make_sdfg(): + sdfg = dace.SDFG("simple_parallel_map") + state = sdfg.add_state("state", is_start_block=True) + for name in ("A", "B", "out1", "out2"): + sdfg.add_array(name, shape=(N1, N2), transient=False, dtype=dace.float64) + sdfg.add_scalar("dmr", dtype=dace.float64, transient=False) + A, B, dmr, out1, out2 = (state.add_access(name) for name in ("A", "B", "dmr", "out1", "out2")) + + _, map1_entry, _ = state.add_mapped_tasklet( + "map_with_dynamic_range", + map_ranges={"__i0": f"0:{N1}", "__i1": f"0:{N2}"}, + inputs={"__in0": dace.Memlet("A[__i0, __i1]")}, + code="__out = __in0 + dynamic_range_value", + outputs={"__out": dace.Memlet("out1[__i0, __i1]")}, + input_nodes={"A": A}, + output_nodes={"out1": out1}, + external_edges=True, + ) + state.add_edge( + dmr, + None, + map1_entry, + "dynamic_range_value", + dace.Memlet("dmr[0]"), + ) + map1_entry.add_in_connector("dynamic_range_value") + + _, map2_entry, _ = state.add_mapped_tasklet( + "map_without_dynamic_range", + map_ranges={"__i2": f"0:{N1}", "__i3": f"0:{N2}"}, + inputs={ + "__in0": dace.Memlet("A[__i2, __i3]"), + "__in1": dace.Memlet("B[__i2, __i3]") + }, + code="__out = __in0 + __in1", + outputs={"__out": dace.Memlet("out2[__i2, __i3]")}, + input_nodes={"A": A, "B": B}, + output_nodes={"out2": out2}, + external_edges=True, + ) + sdfg.validate() + return sdfg, map1_entry, map2_entry + + for mode in range(2): + A = np.random.rand(N1, N2) + B = np.random.rand(N1, N2) + dmr = 3.1415 + out1 = np.zeros_like(A) + out2 = np.zeros_like(B) + res1 = A + dmr + res2 = A + B + + sdfg, map1_entry, map2_entry = _make_sdfg() + + if mode: + map1_entry, map2_entry = map2_entry, map1_entry + + ParallelMapFusion.apply_to( + sdfg, + map_entry1=map1_entry, + map_entry2=map2_entry, + verify=True, + ) + assert count_node(sdfg, dace.sdfg.nodes.MapEntry) == 1 + + sdfg(A=A, B=B, dmr=dmr, out1=out1, out2=out2) + assert np.allclose(out1, res1) + assert np.allclose(out2, res2) + + if __name__ == '__main__': test_indirect_accesses() test_fusion_shared() @@ -439,5 +512,6 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 test_fusion_with_empty_memlet() test_fusion_with_nested_sdfg_0() test_fusion_with_nested_sdfg_1() + test_parallel_fusion_simple() print("SUCCESS") From c23ed39cc72dcef4042f2f5555616c1cf553811c Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 30 Aug 2024 16:59:48 +0200 Subject: [PATCH 030/115] Fixed non proper cycle detection. There is still the dynamic memlet problem. --- .../dataflow/map_fusion_helper.py | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index fcc810aad3..e17dd021ec 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -604,20 +604,28 @@ def partition_first_outputs( # and the second map entry. found_second_map = False consumer_subsets: List[subsets.Subset] = [] - for intermediate_node_out_edges in state.out_edges(intermediate_node): - # Check if we have not reached the second map entry. - # This happens if the intermediate node is a shared node. - # However, we are only allowed to find the second map once. - if intermediate_node_out_edges.dst is not map_entry_2: - continue + for intermediate_node_out_edge in state.out_edges(intermediate_node): + # If we do not reach the second map immediately, we must make sure + # that we will never reach it otherwise we will create cycles. + if intermediate_node_out_edge.dst is not map_entry_2: + if all_nodes_between( + graph=state, + begin=intermediate_node_out_edge.dst, + end=map_entry_2, + ) is None: + continue + return None + + # The second map can only be reached once, because we only handle + # this case. if found_second_map: # TODO(phimuell): Lift this restriction. return None found_second_map = True - assert intermediate_node_out_edges.dst_conn.startswith("IN_") + assert intermediate_node_out_edge.dst_conn.startswith("IN_") consumer_subsets.extend( e.data.src_subset - for e in state.out_edges_by_connector(map_entry_2, "OUT_" + intermediate_node_out_edges.dst_conn[3:]) + for e in state.out_edges_by_connector(map_entry_2, "OUT_" + intermediate_node_out_edge.dst_conn[3:]) ) # The subsets are not set correctly, so we give up. if any(consumer_subset is None for consumer_subset in consumer_subsets): From dad61cbc8a7e1ea7724e02d49edc6c16749bcc2e Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 2 Sep 2024 08:14:13 +0200 Subject: [PATCH 031/115] Modified how the pre exit Memlet (the Memlet that writes in the new intermediate) is generated. Before the function the Memlet was created a new. However, it will now be copied and by that we aim to preserve as much information as possible. The new inner Memlet (the Memlet that goes from teh intermediate to the consumer) was already handled that way. --- .../dataflow/map_fusion_serial.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py index 6c2eacd31b..269e4e02be 100644 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -638,6 +638,21 @@ def handle_intermediate_set( ) new_inter_node: nodes.AccessNode = state.add_access(new_inter_name) + # Get the subset that defined into which part of the old intermediate + # the old output edge wrote to. We need that to adjust the producer + # Memlets, since they now write into the new (smaller) intermediate. + assert pre_exit_edge.data.data == inter_name + assert pre_exit_edge.data.dst_subset is not None + old_pre_exit_edge_subset = pre_exit_edge.data.dst_subset + + # Memlets have a lot of additional informations, such as dynamic. + # To ensure that we get all of them, we will now copy them and modify + # the one that was originally there. We also hope that propagate will + # set the rest for us correctly. + new_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) + new_pre_exit_memlet.data = new_inter_name + new_pre_exit_memlet.dst_subset = subsets.Range.from_array(new_inter_desc) + # New we will reroute the output Memlet, thus it will no longer pass # through the Map exit but through the newly created intermediate. # NOTE: We will delete the previous edge later. @@ -646,16 +661,9 @@ def handle_intermediate_set( pre_exit_edge.src_conn, new_inter_node, None, - sdfg.make_array_memlet(new_inter_name), + new_pre_exit_memlet, ) - # Get the subset that defined into which part of the old intermediate - # the old output edge wrote to. We need that to adjust the producer - # Memlets, since they now write into the new (smaller) intermediate. - assert pre_exit_edge.data.data == inter_name - assert pre_exit_edge.data.dst_subset is not None - old_pre_exit_edge_subset = pre_exit_edge.data.dst_subset - # We now handle the MemletTree defined by this edge. # The newly created edge, only handled the last collection step. for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children(include_self=True): @@ -707,7 +715,7 @@ def handle_intermediate_set( # As for the producer side, we now read from a smaller array, # So we must offset them, we use the original edge for this. - assert inner_edge.data.src_subset is not None, f"{inner_edge} | {inner_edge.data} | {inner_edge.data.src_subset} | {inner_edge.data.dst_subset} " + assert inner_edge.data.src_subset is not None inner_edge_correction_offset = inner_edge.data.src_subset # Now we create a new connection that instead reads from the new From a57ebb3dcfa755d4a2cd261323302c8bee764191 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 2 Sep 2024 08:17:37 +0200 Subject: [PATCH 032/115] Modified how the partition function works. The function now examines the Melets that goes to the intermediate better. They now also test if the Memlet is dynamic. If such a memlet is discovered then it rejects the operation. --- .../dataflow/map_fusion_helper.py | 40 ++++++++++--------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index e17dd021ec..74ab51e080 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -583,12 +583,15 @@ def partition_first_outputs( # - The source of the producer can not be a view (we do not handle this) # - The edge shall also not be a reduction edge. # - Defined location to where they write. + # - No dynamic Melets. # Furthermore, we will also extract the subsets, i.e. the location they # modify inside the intermediate array. producer_subsets: List[subsets.Subset] = [] for producer_edge in producer_edges: if isinstance(producer_edge.src, nodes.AccessNode) and is_view(producer_edge.src, sdfg): return None + if producer_edge.data.dynamic: + return None if producer_edge.data.wcr is not None: return None if producer_edge.data.dst_subset is None: @@ -596,17 +599,15 @@ def partition_first_outputs( producer_subsets.append(producer_edge.data.dst_subset) # Now we determine the consumer of nodes. For this we are using the edges - # leaves the second map entry. We could find the final consumers (e.g. - # Tasklets), however, this might make problems as they depend also on - # symbols defined by nested maps. However, we are not interested in edges, - # but actually what they read, i.e. their source subset. - # In any case there can be at most one connection between the intermediate - # and the second map entry. + # leaves the second map entry. It is not necessary to find the actual + # consumer nodes, as they might depend on symbols of nested Maps. + # For the covering test we only need their subsets, but we will perform + # some scan and filtering on them. found_second_map = False consumer_subsets: List[subsets.Subset] = [] for intermediate_node_out_edge in state.out_edges(intermediate_node): - # If we do not reach the second map immediately, we must make sure - # that we will never reach it otherwise we will create cycles. + + # Ensure that there is no multihop connection to the second map entry. if intermediate_node_out_edge.dst is not map_entry_2: if all_nodes_between( graph=state, @@ -616,24 +617,25 @@ def partition_first_outputs( continue return None - # The second map can only be reached once, because we only handle - # this case. + # Ensure that the second map is found exactly once. if found_second_map: # TODO(phimuell): Lift this restriction. return None found_second_map = True + + # Now we look at all edges that leave the second map entry, as they + # define what is read inside the map. + # NOTE: The subset still uses the old iteration variables. assert intermediate_node_out_edge.dst_conn.startswith("IN_") - consumer_subsets.extend( - e.data.src_subset - for e in state.out_edges_by_connector(map_entry_2, "OUT_" + intermediate_node_out_edge.dst_conn[3:]) - ) - # The subsets are not set correctly, so we give up. - if any(consumer_subset is None for consumer_subset in consumer_subsets): - return None + for inner_consumer_edge in state.out_edges_by_connector(map_entry_2, "OUT_" + intermediate_node_out_edge.dst_conn[3:]): + if inner_consumer_edge.data.src_subset is None: + return None + if inner_consumer_edge.data.dynamic: + return None + consumer_subsets.append(inner_consumer_edge.data.src_subset) assert len(consumer_subsets) != 0 - # Furthermore, the consumer still uses the original symbols of the - # second map, so we must rename them. + # The consumer still uses the original symbols of the second map, so we must rename them. if repl_dict: consumer_subsets = copy.deepcopy(consumer_subsets) for consumer_subset in consumer_subsets: From 439d6f3ea5b6737cb9847fa1995181b9ac9efd0b Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 2 Sep 2024 09:05:37 +0200 Subject: [PATCH 033/115] Modified the map fusion tests. The nested SDFG tests generate dynamic edges, which we do not handle. To ensure that they apply these edges are made non dynamic. This is safe for this case. --- tests/transformations/mapfusion_test.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 4236cd6a46..748e69dc3d 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -381,6 +381,12 @@ def fusion_with_nested_sdfg_0(A: dace.int32[10], B: dace.int32[10], C: dace.int3 A[i] = tmp[i] * 2 sdfg = fusion_with_nested_sdfg_0.to_sdfg(simplify=True) + + # Because the transformation refuses to fuse dynamic edges. + # We have to eliminate them. + for state in sdfg.states(): + for edge in state.edges(): + edge.data.dynamic = False apply_fusion(sdfg) for sd in sdfg.all_sdfgs_recursive(): @@ -411,6 +417,12 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 B[i] = tmp[i] * 2 sdfg = fusion_with_nested_sdfg_1.to_sdfg(simplify=True) + + # Because the transformation refuses to fuse dynamic edges. + # We have to eliminate them. + for state in sdfg.states(): + for edge in state.edges(): + edge.data.dynamic = False apply_fusion(sdfg) if len(sdfg.states()) != 1: From 13c80ec57ca745c1825a4d6a0944bbb792018587 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 2 Sep 2024 12:53:19 +0200 Subject: [PATCH 034/115] More towards bug compatibility. Be a bit more restrictive what a shared intermediate is. If an intermediate node is used somewehere else (or rather its data) then the transformation will not fuse it. Note that this is not a bug in the MapFusion transformation, but in the auto optimizer (most likely SubgraphFusion`). This commit essentially adds a feature to disable some improvements of the transformation such that auto optimizer works. To see this use `tests/npbench/polybench/correlation_test.py`. Use the MapFusion alone and you will see it works. Then use auto optimization and you will see that it does not work. --- .../dataflow/map_fusion_helper.py | 90 ++++++++++++++++--- 1 file changed, 77 insertions(+), 13 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index 74ab51e080..ae3fadd7b8 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -28,18 +28,17 @@ class MapFusionHelper(transformation.SingleStateTransformation): Args: only_inner_maps: Only match Maps that are internal, i.e. inside another Map. only_toplevel_maps: Only consider Maps that are at the top. + ssa_sdfg: If `True` assumes that the SDFG is in SSA style, this will skip some checks. """ only_toplevel_maps = properties.Property( dtype=bool, default=False, - allow_none=False, desc="Only perform fusing if the Maps are in the top level.", ) only_inner_maps = properties.Property( dtype=bool, default=False, - allow_none=False, desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", ) shared_data = properties.DictProperty( @@ -49,14 +48,22 @@ class MapFusionHelper(transformation.SingleStateTransformation): allow_none=True, optional=True, # Do not serialize. optional_condition=lambda _: False, - desc="Maps SDFGs to the set of data that can not be removed. " - "The variable acts as a cache, and is managed by 'is_shared_data()'.", + desc="Maps SDFGs to the set of data that can not be removed," + " because they transmit data _between states_, such data will be made 'shared'." + " This variable acts as a cache, and is managed by 'is_shared_data()'.", + ) + ssa_sdfg = properties.Property( + dtype=bool, + default=False, + desc="If `True` then the transformation assumes the SDFG uses SSA style assignments", ) + def __init__( self, only_inner_maps: Optional[bool] = None, only_toplevel_maps: Optional[bool] = None, + ssa_sdfg: Optional[bool] = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -64,6 +71,8 @@ def __init__( self.only_toplevel_maps = bool(only_toplevel_maps) if only_inner_maps is not None: self.only_inner_maps = bool(only_inner_maps) + if ssa_sdfg is not None: + self.ssa_sdfg = bool(ssa_sdfg) self.shared_data = {} @@ -436,6 +445,45 @@ def _compute_shared_data( self.shared_data[sdfg] = shared_data + def _compute_multi_write_data( + self, + state: SDFGState, + sdfg: SDFG, + ) -> Set[str]: + """Computes data inside a _single_ state, that is written multiple times. + + Essentially this function computes the set of data that does not follow + the single static assignment idiom. The function also resolves views. + If an access node that refers to a view, the function will add not only + the view itself, but also the data it refers to. + + Args: + state: The state that should be examined. + sdfg: The SDFG object. + + Note: + This information is used by the partition function, if it is legal to turn + a intermediate node into shared output or if the partition does not exists + at all. The current implementation is rather simple as it only checks if + a data is written to multiple times in the same state. + Actually everything could be turned into a shared output, however, some + DaCe transformation fail to proper examine the graph and detect these cases. + """ + data_written_to: Set[str] = set() + multi_write_data: Set[str] = set() + + for access_node in state.data_nodes(): + if state.in_degree(access_node) == 0: + continue + if is_view(access_node, sdfg): + # This is an over approximation. + multi_write_data.update([access_node.data, track_view(access_node, state, sdfg).data]) + elif access_node.data in data_written_to: + multi_write_data.add(access_node.data) + data_written_to.add(access_node.data) + return multi_write_data + + def partition_first_outputs( self, state: SDFGState, @@ -501,6 +549,19 @@ def partition_first_outputs( # Set of intermediate nodes that we have already processed. processed_inter_nodes: Set[nodes.Node] = set() + # These are the data that is written to multiple times in this state. + # If a data is written to multiple time in a state, it could be + # classified as shared. A problem could be if, this shared node happens to + # then have zero out degree, thus dependencies are given by the edges that + # leave the second exit node and not by the output nodes of the intermediate + # node. Because some other DaCe transformation (auto optimizer) fail to + # take this into account properly they do transformations that are invalid. + # Thus we will never modify such intermediate nodes. + if not self.ssa_sdfg: + multi_write_data: Set[str] = self._compute_multi_write_data(state, sdfg) + else: + multi_write_data = set() + # Now scan all output edges of the first exit and classify them for out_edge in state.out_edges(map_exit_1): intermediate_node: nodes.Node = out_edge.dst @@ -514,6 +575,7 @@ def partition_first_outputs( # The intermediate can only have one incoming degree. It might be possible # to handle multiple incoming edges, if they all come from the top map. # However, the resulting SDFG might be invalid. + # NOTE: If needed the output degree is changed further down. if state.in_degree(intermediate_node) != 1: return None @@ -538,6 +600,16 @@ def partition_first_outputs( # cases, as handling them is essentially rerouting an edge, whereas # handling intermediate nodes is much more complicated. + # Checks if the intermediate node refers to data that is accessed by + # _other_ access nodes in _this_ state. If this is the case then never + # touch this intermediate node. + # TODO(phimuell): Technically it would be enough to turn the node into + # a shared output node, because this will still fulfil the dependencies. + # However, some DaCe transformation can not handle this properly, so we + # are _forced_ to reject this node. + if intermediate_node.data in multi_write_data: + return None + # If `downstream_nodes` is empty, this means that the second map entry # was found immediately, we only allow the case that there is one # connecting Memlet. @@ -560,15 +632,6 @@ def partition_first_outputs( if out_edge.data.is_empty(): return None - # The intermediate now can only have a single source. It might be possible - # to extend this to many inputs as long as they come from the top map. - # NOTE: The output degree is checked implicitly further down, the - # general rule is, that multiple outputs are only allowed if only - # one enters the second Map, the other output must go to different - # consumers, in which case the node is a shared intermediate. - if state.in_degree(intermediate_node) != 1: - return None - # It can happen that multiple edges converges at the `IN_` connector # of the first map exit, but there is only one edge leaving the exit. # It is complicate to handle this, so for now we ignore it. @@ -631,6 +694,7 @@ def partition_first_outputs( if inner_consumer_edge.data.src_subset is None: return None if inner_consumer_edge.data.dynamic: + # TODO(phimuell): Is this restriction necessary, I am not sure. return None consumer_subsets.append(inner_consumer_edge.data.src_subset) assert len(consumer_subsets) != 0 From 12a5cf7cc79b724a246acd362c38ada6b9fdd712 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 2 Sep 2024 13:15:47 +0200 Subject: [PATCH 035/115] The `ssa_sdfg` parameter was noit named properly. I have now changed its name to `strict_dataflow` which is a bit more informative. --- .../dataflow/map_fusion_helper.py | 31 +++++++++++++------ 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index ae3fadd7b8..11713a52ea 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -28,7 +28,15 @@ class MapFusionHelper(transformation.SingleStateTransformation): Args: only_inner_maps: Only match Maps that are internal, i.e. inside another Map. only_toplevel_maps: Only consider Maps that are at the top. - ssa_sdfg: If `True` assumes that the SDFG is in SSA style, this will skip some checks. + strict_dataflow: If `True`, the default, the transformation ensures strict + data flow. + + Note: + `strict_dataflow` only has an influence if there is a downstream connection + from one access node to another that both _write_ to the same data. + Technically it is possible to turn the upstream access node into a shared output, + see `partition_first_outputs()`, data dependency is still guaranteed by the + maps. However, some other DaCe transformation cannot properly handle this case. """ only_toplevel_maps = properties.Property( @@ -52,10 +60,10 @@ class MapFusionHelper(transformation.SingleStateTransformation): " because they transmit data _between states_, such data will be made 'shared'." " This variable acts as a cache, and is managed by 'is_shared_data()'.", ) - ssa_sdfg = properties.Property( + strict_dataflow = properties.Property( dtype=bool, - default=False, - desc="If `True` then the transformation assumes the SDFG uses SSA style assignments", + default=True, + desc="If `False` then the transformation will not preserve strict data flow.", ) @@ -63,7 +71,7 @@ def __init__( self, only_inner_maps: Optional[bool] = None, only_toplevel_maps: Optional[bool] = None, - ssa_sdfg: Optional[bool] = None, + strict_dataflow: Optional[bool] = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -71,8 +79,8 @@ def __init__( self.only_toplevel_maps = bool(only_toplevel_maps) if only_inner_maps is not None: self.only_inner_maps = bool(only_inner_maps) - if ssa_sdfg is not None: - self.ssa_sdfg = bool(ssa_sdfg) + if strict_dataflow is not None: + self.strict_dataflow = bool(strict_dataflow) self.shared_data = {} @@ -465,9 +473,12 @@ def _compute_multi_write_data( This information is used by the partition function, if it is legal to turn a intermediate node into shared output or if the partition does not exists at all. The current implementation is rather simple as it only checks if - a data is written to multiple times in the same state. + a data is written to multiple times in the same state. A more refined + (but still simple) implementation would take the location of the access + node into consideration. Actually everything could be turned into a shared output, however, some - DaCe transformation fail to proper examine the graph and detect these cases. + DaCe transformation fail to proper examine the graph in this case and + perform modifications that lead to wrong behaviour. """ data_written_to: Set[str] = set() multi_write_data: Set[str] = set() @@ -557,7 +568,7 @@ def partition_first_outputs( # node. Because some other DaCe transformation (auto optimizer) fail to # take this into account properly they do transformations that are invalid. # Thus we will never modify such intermediate nodes. - if not self.ssa_sdfg: + if self.strict_dataflow: multi_write_data: Set[str] = self._compute_multi_write_data(state, sdfg) else: multi_write_data = set() From 995ef4dbeaccafdc4d80e52a65efd390684ea87a Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 2 Sep 2024 13:21:06 +0200 Subject: [PATCH 036/115] Auto optimize now uses map fusion with a strict dataflow. This is done to work around some wrong behaviour in composite fusion. --- dace/transformation/auto/auto_optimize.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/dace/transformation/auto/auto_optimize.py b/dace/transformation/auto/auto_optimize.py index 7bced3bec9..09ff481e39 100644 --- a/dace/transformation/auto/auto_optimize.py +++ b/dace/transformation/auto/auto_optimize.py @@ -58,7 +58,12 @@ def greedy_fuse(graph_or_subgraph: GraphViewType, # If we have an SDFG, recurse into graphs graph_or_subgraph.simplify(validate_all=validate_all) # MapFusion for trivial cases - graph_or_subgraph.apply_transformations_repeated(MapFusion, validate_all=validate_all) + # We have to use `strict_dataflow` because it is known that `CompositeFusion` + # has problems otherwise. + graph_or_subgraph.apply_transformations_repeated( + MapFusion(strict_dataflow=True), + validate_all=validate_all, + ) # recurse into graphs for graph in graph_or_subgraph.nodes(): @@ -76,7 +81,10 @@ def greedy_fuse(graph_or_subgraph: GraphViewType, sdfg, graph, subgraph = None, None, None if isinstance(graph_or_subgraph, SDFGState): sdfg = graph_or_subgraph.parent - sdfg.apply_transformations_repeated(MapFusion, validate_all=validate_all) + sdfg.apply_transformations_repeated( + MapFusion(strict_dataflow=True), + validate_all=validate_all, + ) graph = graph_or_subgraph subgraph = SubgraphView(graph, graph.nodes()) else: From dc4ed31b2f89287b146d2dfb1e7208d3c42c95df Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 2 Sep 2024 13:39:03 +0200 Subject: [PATCH 037/115] Fixed the classification function for shared transient. --- .../dataflow/map_fusion_helper.py | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index 11713a52ea..d264b80065 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -378,7 +378,6 @@ def is_shared_data( serial map fusion. The function determine this properties, according to the following rules: - - The access node must be in the top scope. - The underlying data is global. - The `data` descriptor is used multiple times with the same state. - `data` has an out or in degree of zero. @@ -423,24 +422,22 @@ def _compute_shared_data( # We go through all states and classify the nodes, according to the rules. prevously_seen_data: Set[str] = set() for state in sdfg.nodes(): - scope_dict = state.scope_dict() for access_node in state.data_nodes(): - if scope_dict[access_node] is not None: - # We are only interested in global data. - pass - elif access_node.data in shared_data: - # The data was already determined to be shared data + if access_node.data in shared_data: + # The data was already classified to be shared data pass elif access_node.data in prevously_seen_data: # We have seen this data before, either in this state or in - # a previous one, but we did not classifies it as shared, - # let's do this now. Note that we do not remove the data - # also from `previously_seen_data`. + # a previous one, but we did not classifies it as shared back then + # Note that we do not remove the data also from `previously_seen_data`. shared_data.add(access_node.data) - elif state.out_degree(access_node) == 0: - # Sink and source nodes also have to be kept. + if state.in_degree(access_node) == 0: + # (Transient) sink nodes are used in other states, or simplify + # will get rid of them. shared_data.add(access_node.data) - elif state.in_degree(access_node) == 0: + elif state.out_degree(access_node) != 0: # state.out_degree() == 0 or state.out_degree() > 1 + # The node is either a source node (it is shared in another state). + # Output degree of more than one, means it is used in another state. shared_data.add(access_node.data) else: # The node was not classified as shared data, so we record that From eb48391a3c2ad83839e8c554ad642aa31c108c2b Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 2 Sep 2024 14:52:07 +0200 Subject: [PATCH 038/115] Ensure that we have one state in the fusion test. I think that this is the reason. --- tests/transformations/mapfusion_test.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 748e69dc3d..d2b12347ff 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -174,7 +174,7 @@ def fusion_indirect_access(A: dace.float32[100], B: dace.float32[100], idx: dace def test_fusion_simple(): - sdfg = fusion_simple.to_sdfg() + sdfg = fusion_simple.to_sdfg(simplify=True) sdfg = apply_fusion(sdfg, final_maps=1) A = np.random.rand(10, 20).astype(np.float32) @@ -188,7 +188,7 @@ def test_fusion_simple(): def test_fusion_rename(): - sdfg = fusion_rename.to_sdfg() + sdfg = fusion_rename.to_sdfg(simplify=True) sdfg = apply_fusion(sdfg, final_maps=1) A = np.random.rand(10, 20).astype(np.float32) @@ -202,7 +202,7 @@ def test_fusion_rename(): def test_fusion_shared(): - sdfg = fusion_shared_output.to_sdfg() + sdfg = fusion_shared_output.to_sdfg(simplify=True) sdfg = apply_fusion(sdfg) A = np.random.rand(10, 20).astype(np.float32) @@ -218,7 +218,7 @@ def test_fusion_shared(): def test_indirect_accesses(): - sdfg = fusion_indirect_access.to_sdfg() + sdfg = fusion_indirect_access.to_sdfg(simplify=True) sdfg = apply_fusion(sdfg, final_maps=2) A = np.random.rand(100).astype(np.float32) @@ -233,7 +233,7 @@ def test_indirect_accesses(): def test_multiple_fusions(): - sdfg = multiple_fusions.to_sdfg() + sdfg = multiple_fusions.to_sdfg(simplify=True) sdfg.save(os.path.join('_dacegraphs', 'before2.sdfg')) sdfg.simplify() @@ -254,7 +254,7 @@ def test_multiple_fusions(): def test_fusion_chain(): - sdfg = fusion_chain.to_sdfg() + sdfg = fusion_chain.to_sdfg(simplify=True) sdfg.simplify() sdfg = apply_fusion(sdfg, final_maps=1) @@ -270,7 +270,7 @@ def test_fusion_chain(): def test_fusion_with_transient(): A = np.random.rand(2, 20) expected = A * A * 2 - sdfg = fusion_with_transient.to_sdfg() + sdfg = fusion_with_transient.to_sdfg(simplify=True) sdfg.simplify() sdfg = apply_fusion(sdfg, removed_maps=2) From d53302d3cc250a817c8248256f4e9368c67b6789 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 2 Sep 2024 17:08:46 +0200 Subject: [PATCH 039/115] Fixed a problem in the classification of the shared transients. It was a symple typo. --- dace/transformation/dataflow/map_fusion_helper.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index d264b80065..986fb00d60 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -435,9 +435,10 @@ def _compute_shared_data( # (Transient) sink nodes are used in other states, or simplify # will get rid of them. shared_data.add(access_node.data) - elif state.out_degree(access_node) != 0: # state.out_degree() == 0 or state.out_degree() > 1 - # The node is either a source node (it is shared in another state). - # Output degree of more than one, means it is used in another state. + elif state.out_degree(access_node) != 1: # state.out_degree() == 0 or state.out_degree() > 1 + # The access node is either a source node (it is shared in another + # state) or the node has a degree larger than one, so it is used + # in this state somewhere else. shared_data.add(access_node.data) else: # The node was not classified as shared data, so we record that From 3f54e1f807ab44f140a92a39448ca4cc69b8a33f Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 2 Sep 2024 17:09:43 +0200 Subject: [PATCH 040/115] Updated the rw conflict code a little bit. It now also checks if two different views refer to the same data. --- .../dataflow/map_fusion_serial.py | 36 ++++++++++--------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py index 269e4e02be..c9d7fcde55 100644 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -125,24 +125,22 @@ def has_read_write_dependency( state: SDFGState, sdfg: SDFG, ) -> bool: - """Test if there is a read write dependency between the two maps. + """Test if there is a read write dependency between the two maps to be fused. - The function checks if the first map does not read anything from - a data descriptor, the second map writes into. + The function first looks at the set of data that is read/written by the + two maps. If the function detects a possible conflict, the function will + evaluate the subsets of the read and write to determine if the conflict + can be resolved or not. Returns: - `True` if there is a conflict between input and outputs, `False` - if there is no conflict. + `True` if there is a conflict between the maps that can not be handled. + If there is no conflict or if the conflict can be handled then `False` + is returned. Args: map_entry_1: The entry node of the first map. map_entry_2: The entry node of the second map. state: The state on which we operate. - - Note: - The current implementation just computes the set of data that is - used as input and for output. If there is an intersection then - the function considers this as a read write conflict. """ map_exit_1: nodes.MapExit = state.exit_node(map_entry_1) map_exit_2: nodes.MapExit = state.exit_node(map_entry_2) @@ -168,15 +166,19 @@ def has_read_write_dependency( }) real_read_map_1, real_write_map_1, real_read_map_2, real_write_map_2 = resolved_sets - # We now test for "structural problems", i.e. problems where the resulting - # SDFG would be invalid, all of these cases are characterized by the fact - # that both maps write to the same data. This is hard or impossible to - # handle, so we forbid all these cases. + # If the resolved and the unresolved set of input/output have different lengths, + # it means that there were two different views that ultimately referred to the + # same data. + for unresolved_access, resolved_access in zip(access_sets, resolved_sets): + if len(unresolved_access) != len(resolved_access): + return True + + # We do not allow that the first and second map each write to the same data. + # The reason is because it is very hard to handle correctly. if not real_write_map_1.isdisjoint(real_write_map_2): return True - # We will now test if there are no conflicts, for this we require that all - # input is distinct from the all the output. + # The inputs and outputs are different, so there can not be any conflicts. # Must be done after the test above! if (real_read_map_1 | real_read_map_2).isdisjoint(real_write_map_1 | real_write_map_2): return False @@ -364,6 +366,8 @@ def _test_if_subsets_are_point_wise( If a series of subsets are point wise it means that all Memlets, access the same data. This is an important property because the whole map fusion is build upon this. + If the subsets originates from different maps, then they must have been + renamed. Args: subsets_to_check: The list of subsets that should be checked. From 0f46c1c2d36248520470bcf81aaf6f25ec670501 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 4 Sep 2024 13:42:49 +0200 Subject: [PATCH 041/115] Updated the `SDFGState._read_and_write_set()` function. In some cases the `dst_subset` property of a Memlet is not set for example this happens if the edge goes to a Tasklet, in this case the value is None. --- dace/sdfg/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 1428564f4e..08d5229c7d 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -763,7 +763,7 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, for out_edge in list(out_edges): for in_edge in list(in_edges): if (in_edge.data.data == out_edge.data.data - and in_edge.data.dst_subset.covers(out_edge.data.src_subset)): + and (in_edge.data.dst_subset is not None and in_edge.data.dst_subset.covers(out_edge.data.src_subset))): out_edges.remove(out_edge) break From 05851127ece69e33644b2c960004dc343a6a73bb Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 4 Sep 2024 14:06:10 +0200 Subject: [PATCH 042/115] Extend the range of `strict_dataflow`. The taransformation will ensure that dimensions of length 1, which are created by overapproximation, of the new inner intermediate are removed. However, in conjunction with auto optimize this might lead to problems, thus the strict dataflow mode will also stop this behaviour. I observe this behaviour in `tests/npbench/weather_stencils/vadv_test.py`. However, it is not caused by auto optimizer directly but by redundant array removal. However, I But I have never seen this behaviour inmy tests. --- .../dataflow/map_fusion_serial.py | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py index c9d7fcde55..26e14bd5af 100644 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -596,20 +596,27 @@ def handle_intermediate_set( pre_exit_edge = pre_exit_edges[0] new_inter_shape_raw = symbolic.overapproximate(pre_exit_edge.data.subset.size()) - # Over approximation will leave us with some unneeded size one dimensions. - # That are known to cause some troubles, so we will now remove them. - squeezed_dims: List[int] = [] # These are the dimensions we removed. - new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. - for dim, (proposed_dim_size, full_dim_size) in enumerate( - zip(new_inter_shape_raw, inter_shape) - ): - # Order of checks is important! - if full_dim_size == 1: # Must be kept! - new_inter_shape.append(proposed_dim_size) - elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. - squeezed_dims.append(dim) - else: - new_inter_shape.append(proposed_dim_size) + if not self.strict_dataflow: + # Over approximation will leave us with some unneeded size one dimensions. + # If they are removed some dace transformations (especially auto optimization) + # will have problems. + squeezed_dims: List[int] = [] # These are the dimensions we removed. + new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. + for dim, (proposed_dim_size, full_dim_size) in enumerate( + zip(new_inter_shape_raw, inter_shape) + ): + # Order of checks is important! + if full_dim_size == 1: # Must be kept! + new_inter_shape.append(proposed_dim_size) + elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. + squeezed_dims.append(dim) + else: + new_inter_shape.append(proposed_dim_size) + else: + squeezed_dims = [] + new_inter_shape = list(new_inter_shape_raw) + + # This is the name of the new "intermediate" node that we will create. # It will only have the shape `new_inter_shape` which is basically its From 61160e4e1f7757437b5b9faf387e75ea91fa722a Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 4 Sep 2024 14:10:02 +0200 Subject: [PATCH 043/115] Fixed a bug in the shared mode handling. The output memlet's source set was not set correctly. --- dace/transformation/dataflow/map_fusion_serial.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py index 26e14bd5af..cccf2639c6 100644 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -796,12 +796,9 @@ def handle_intermediate_set( # temporary node to the Map output. This will essentially restore # or preserve the output for the intermediate node. It is important # that we use the data that `preExitEdge` was used. - new_exit_memlet = copy.deepcopy(pre_exit_edge.data) - assert new_exit_memlet.data == inter_name - new_exit_memlet.subset = pre_exit_edge.data.dst_subset - new_exit_memlet.other_subset = ( - "0" if is_scalar else subsets.Range.from_array(inter_desc) - ) + final_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) + assert pre_exit_edge.data.data == inter_name + final_pre_exit_memlet.other_subset = subsets.Range.from_array(new_inter_desc) new_pre_exit_conn = map_exit_2.next_connector() state.add_edge( @@ -809,7 +806,7 @@ def handle_intermediate_set( None, map_exit_2, "IN_" + new_pre_exit_conn, - new_exit_memlet, + final_pre_exit_memlet, ) state.add_edge( map_exit_2, From 584635a3a7951cfb97b1215e72946dffa1715fc8 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 4 Sep 2024 14:11:13 +0200 Subject: [PATCH 044/115] Misc changes to the serial fusion. --- dace/transformation/dataflow/map_fusion_serial.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py index cccf2639c6..8bab1b6007 100644 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -545,8 +545,8 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non map_exit_2.map = map_entry_1.map - @staticmethod def handle_intermediate_set( + self, intermediate_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]], state: SDFGState, sdfg: SDFG, @@ -777,6 +777,7 @@ def handle_intermediate_set( if is_exclusive_set: # In exclusive mode the old intermediate node is no longer needed. + # This will also remove `out_edge` from the SDFG. assert state.degree(inter_node) == 1 state.remove_edge_and_connectors(out_edge) state.remove_node(inter_node) From b2dea1d1763627539043ca1b47341d2204fcbfc7 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 4 Sep 2024 14:57:23 +0200 Subject: [PATCH 045/115] Fixed a wrong test. --- tests/python_frontend/fields_and_global_arrays_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python_frontend/fields_and_global_arrays_test.py b/tests/python_frontend/fields_and_global_arrays_test.py index b7f5e46ee9..03cb4c5915 100644 --- a/tests/python_frontend/fields_and_global_arrays_test.py +++ b/tests/python_frontend/fields_and_global_arrays_test.py @@ -585,7 +585,7 @@ def caller(): # Ensure only three globals are created sdfg = caller.to_sdfg() - assert len([k for k in sdfg.arrays if '__g' in k]) == 3 + assert len([k for k in sdfg.arrays if k.startswith('__g')]) == 3 def test_two_inner_methods(): From fba668228b3b77700e0c5f11882f8953b80f50df Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 4 Sep 2024 14:59:29 +0200 Subject: [PATCH 046/115] Removed the strange nested SDFG check. --- .../dataflow/map_fusion_helper.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index 986fb00d60..c42b47730f 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -132,9 +132,6 @@ def can_be_fused( elif self.only_toplevel_maps: if scope[map_entry_1] is not None: return False - # TODO(phimuell): Figuring out why this is here. - elif is_nested_sdfg(sdfg): - return False # We will now check if there exists a remapping that of the map parameter if self.find_parameter_remapping(first_map=map_entry_1.map, second_map=map_entry_2.map) is None: @@ -744,22 +741,6 @@ def partition_first_outputs( return (pure_outputs, exclusive_outputs, shared_outputs) -def is_nested_sdfg( - sdfg: Union[dace.SDFG, dace.SDFGState, nodes.NestedSDFG], -) -> bool: - """Tests if `sdfg` is a NestedSDFG.""" - if isinstance(sdfg, dace.SDFGState): - sdfg = sdfg.parent - if isinstance(sdfg, nodes.NestedSDFG): - return True - elif isinstance(sdfg, dace.SDFG): - if sdfg.parent_nsdfg_node is not None: - return True - return False - else: - raise TypeError(f"Does not know how to handle '{type(sdfg).__name__}'.") - - def all_nodes_between( graph: Union[dace.SDFG, dace.SDFGState], begin: nodes.Node, From 62b028808f8bcc2f4aae9b410df60671e57927a9 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 4 Sep 2024 16:54:59 +0200 Subject: [PATCH 047/115] Made some modifications to the map fusion stuff. --- .../dataflow/map_fusion_helper.py | 484 +++++++----------- .../dataflow/map_fusion_parallel.py | 49 +- .../dataflow/map_fusion_serial.py | 10 +- 3 files changed, 233 insertions(+), 310 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index c42b47730f..a9247bf591 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -16,27 +16,21 @@ @properties.make_properties class MapFusionHelper(transformation.SingleStateTransformation): - """Contains common part of the fusion for parallel and serial Map fusion. - - The transformation assumes that the SDFG obeys the principals outlined [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). - The main advantage of this structure is, that it is rather easy to determine - if a transient is used anywhere else. This check, performed by - `is_shared_data()`. It is further speeded up by cashing some computation, - thus such an object should not be used after interstate optimizations were applied - to the SDFG. + """Common parts of the parallel and serial map fusion transformation. Args: only_inner_maps: Only match Maps that are internal, i.e. inside another Map. only_toplevel_maps: Only consider Maps that are at the top. - strict_dataflow: If `True`, the default, the transformation ensures strict - data flow. + strict_dataflow: If `True`, the default, the transformation ensures a more + stricter version of the data flow. Note: - `strict_dataflow` only has an influence if there is a downstream connection - from one access node to another that both _write_ to the same data. - Technically it is possible to turn the upstream access node into a shared output, - see `partition_first_outputs()`, data dependency is still guaranteed by the - maps. However, some other DaCe transformation cannot properly handle this case. + If `strict_dataflow` mode is enabled then the transformation will not remove + _direct_ data flow dependency from the graph. Furthermore, the transformation + will not remove size 1 dimensions of intermediate it crates. + This is a compatibility mode, that will limit the applicability of the + transformation, but might help transformations that does not fully analyse + the graph. """ only_toplevel_maps = properties.Property( @@ -63,7 +57,7 @@ class MapFusionHelper(transformation.SingleStateTransformation): strict_dataflow = properties.Property( dtype=bool, default=True, - desc="If `False` then the transformation will not preserve strict data flow.", + desc="If `True` then the transformation will ensure a more stricter data flow.", ) @@ -140,8 +134,8 @@ def can_be_fused( return True - @staticmethod def relocate_nodes( + self, from_node: Union[nodes.MapExit, nodes.MapEntry], to_node: Union[nodes.MapExit, nodes.MapEntry], state: SDFGState, @@ -236,8 +230,8 @@ def relocate_nodes( assert len(from_node.out_connectors) == 0 - @staticmethod def find_parameter_remapping( + self, first_map: nodes.Map, second_map: nodes.Map ) -> Union[Dict[str, str], None]: @@ -370,27 +364,17 @@ def is_shared_data( ) -> bool: """Tests if `data` is interstate data, an can not be removed. - Interstate data is used to transmit data between multiple state or - by extension within the state, and thus can not be removed by the - serial map fusion. - - The function determine this properties, according to the following rules: - - The underlying data is global. - - The `data` descriptor is used multiple times with the same state. - - `data` has an out or in degree of zero. - - The underlying data is referred to in another state. - - The function computes this information and then caches it for later use. + Interstate data is used to transmit data between multiple state or by + extension within the state. Thus it must be classified as a shared output. + This function will go through the SDFG to and collect the names of all data + container that should be classified as shared. Args: transient: The transient that should be checked. sdfg: The SDFG containing the array. Note: - - This function does not inspect the interstate edges, instead the - set of data that is accessed in interstate edges is approximated - with the set of sink nodes. - - This function works best if the SDFG uses SSA style. + The function computes the this set once for every SDFG and then caches it. """ if sdfg not in self.shared_data: self._compute_shared_data(sdfg) @@ -401,7 +385,7 @@ def _compute_shared_data( self, sdfg: dace.SDFG, ) -> None: - """This function computes the set of shared data for SDFG `sdfg`. + """Updates the internal set of shared data/interstate data of `self` for `sdfg`. See the documentation for `self.is_shared_data()` for a description. @@ -411,32 +395,49 @@ def _compute_shared_data( # Shared data of this SDFG. shared_data: Set[str] = set() - # Add all global data. + # All global data can not be removed, so it must always be shared. for data_name, data_desc in sdfg.arrays.items(): if not data_desc.transient: shared_data.add(data_name) + elif isinstance(data_desc, dace.data.Scalar): + shared_data.add(data_name) - # We go through all states and classify the nodes, according to the rules. + # We go through all states and classify the nodes/data: + # - Data is referred to in different states. + # - The access node is a view (both have to survive). + # - Transient sink or source node. + # - The access node has output degree larger than 1 (input degrees larger + # than one, will always be partitioned as shared anyway). prevously_seen_data: Set[str] = set() for state in sdfg.nodes(): for access_node in state.data_nodes(): + if access_node.data in shared_data: # The data was already classified to be shared data pass + elif access_node.data in prevously_seen_data: # We have seen this data before, either in this state or in # a previous one, but we did not classifies it as shared back then - # Note that we do not remove the data also from `previously_seen_data`. shared_data.add(access_node.data) + if state.in_degree(access_node) == 0: # (Transient) sink nodes are used in other states, or simplify # will get rid of them. shared_data.add(access_node.data) + elif state.out_degree(access_node) != 1: # state.out_degree() == 0 or state.out_degree() > 1 # The access node is either a source node (it is shared in another # state) or the node has a degree larger than one, so it is used # in this state somewhere else. shared_data.add(access_node.data) + + elif self.is_view(node=access_node, sdfg=sdfg): + # To ensure that the write to the view happens, both have to be shared. + viewed_data: str = self.track_view(view=access_node, state=state, sdfg=sdfg).data + shared_data.update([access_node.data, viewed_data]) + prevously_seen_data.update([access_node.data, viewed_data]) + else: # The node was not classified as shared data, so we record that # we saw it. Note that a node that was immediately classified @@ -444,6 +445,14 @@ def _compute_shared_data( # that was found twice will be inside this list. prevously_seen_data.add(access_node.data) + # Now we are collecting all symbols that interstate edges read from. + interstate_read_symbols: Set[str] = set() + for edge in sdfg.edges(): + interstate_read_symbols.update(edge.read_symbols()) + + # We also have to keep everything the edges referrers to and is an array. + shared_data.update(interstate_read_symbols.intersection(prevously_seen_data)) + # Update the internal cache self.shared_data[sdfg] = shared_data @@ -465,15 +474,13 @@ def _compute_multi_write_data( sdfg: The SDFG object. Note: - This information is used by the partition function, if it is legal to turn - a intermediate node into shared output or if the partition does not exists - at all. The current implementation is rather simple as it only checks if - a data is written to multiple times in the same state. A more refined - (but still simple) implementation would take the location of the access - node into consideration. - Actually everything could be turned into a shared output, however, some - DaCe transformation fail to proper examine the graph in this case and - perform modifications that lead to wrong behaviour. + This information is used by the partition function (in case strict data + flow mode is enabled), if it is legal to turn a intermediate node into + shared output or if the partition does not exists at all. The current + implementation is rather simple as it only checks if a data is written + to multiple times in the same state. A more refined (but still simple) + implementation would take the location of the access node into + consideration. """ data_written_to: Set[str] = set() multi_write_data: Set[str] = set() @@ -481,7 +488,7 @@ def _compute_multi_write_data( for access_node in state.data_nodes(): if state.in_degree(access_node) == 0: continue - if is_view(access_node, sdfg): + if self.is_view(access_node, sdfg): # This is an over approximation. multi_write_data.update([access_node.data, track_view(access_node, state, sdfg).data]) elif access_node.data in data_written_to: @@ -555,13 +562,11 @@ def partition_first_outputs( # Set of intermediate nodes that we have already processed. processed_inter_nodes: Set[nodes.Node] = set() - # These are the data that is written to multiple times in this state. + # These are the data that is written to multiple times in _this_ state. # If a data is written to multiple time in a state, it could be - # classified as shared. A problem could be if, this shared node happens to - # then have zero out degree, thus dependencies are given by the edges that - # leave the second exit node and not by the output nodes of the intermediate - # node. Because some other DaCe transformation (auto optimizer) fail to - # take this into account properly they do transformations that are invalid. + # classified as shared. However, it might happen that the node has zero + # degree. This is not a problem as the maps also induced a before-after + # relationship. But some DaCe transformations do not catch this. # Thus we will never modify such intermediate nodes. if self.strict_dataflow: multi_write_data: Set[str] = self._compute_multi_write_data(state, sdfg) @@ -574,6 +579,7 @@ def partition_first_outputs( # We already processed the node, this should indicate that we should # run simplify again, or we should start implementing this case. + # TODO(phimuell): Handle this case, it is partially implemented. if intermediate_node in processed_inter_nodes: return None processed_inter_nodes.add(intermediate_node) @@ -581,14 +587,15 @@ def partition_first_outputs( # The intermediate can only have one incoming degree. It might be possible # to handle multiple incoming edges, if they all come from the top map. # However, the resulting SDFG might be invalid. - # NOTE: If needed the output degree is changed further down. + # NOTE: Allow this to happen (under certain cases) if the only producer + # is the top map. if state.in_degree(intermediate_node) != 1: return None # Now let's look at all nodes that are downstream of the intermediate node. # This, among other things, will tell us, how we have to handle this node. # NOTE: The traversal will stop at the second map. - downstream_nodes = all_nodes_between( + downstream_nodes = self.all_nodes_between( graph=state, begin=intermediate_node, end=map_entry_2, @@ -616,12 +623,6 @@ def partition_first_outputs( if intermediate_node.data in multi_write_data: return None - # If `downstream_nodes` is empty, this means that the second map entry - # was found immediately, we only allow the case that there is one - # connecting Memlet. - if (len(downstream_nodes) == 0) and state.out_degree(intermediate_node) != 1: - return None - # For us an intermediate node must always be an access node, because # everything else we do not know how to handle. It is important that # we do not test for non transient data here, because they can be @@ -655,9 +656,10 @@ def partition_first_outputs( # - No dynamic Melets. # Furthermore, we will also extract the subsets, i.e. the location they # modify inside the intermediate array. + # Since we do not allow for WCR, we do not check if the producer subsets intersects. producer_subsets: List[subsets.Subset] = [] for producer_edge in producer_edges: - if isinstance(producer_edge.src, nodes.AccessNode) and is_view(producer_edge.src, sdfg): + if isinstance(producer_edge.src, nodes.AccessNode) and self.is_view(producer_edge.src, sdfg): return None if producer_edge.data.dynamic: return None @@ -678,7 +680,7 @@ def partition_first_outputs( # Ensure that there is no multihop connection to the second map entry. if intermediate_node_out_edge.dst is not map_entry_2: - if all_nodes_between( + if self.all_nodes_between( graph=state, begin=intermediate_node_out_edge.dst, end=map_entry_2, @@ -703,6 +705,7 @@ def partition_first_outputs( # TODO(phimuell): Is this restriction necessary, I am not sure. return None consumer_subsets.append(inner_consumer_edge.data.src_subset) + assert found_second_map, f"Found '{intermediate_node}' which looked like a pure node, but is not one." assert len(consumer_subsets) != 0 # The consumer still uses the original symbols of the second map, so we must rename them. @@ -725,12 +728,10 @@ def partition_first_outputs( # Note that "removed" here means that it is reconstructed by a new # output of the second map. if len(downstream_nodes) != 0: - # The intermediate node is connected to more node inside this state, - # that are not inside the map, so we must keep it alive. + # The intermediate node is connected to more node inside _this_ state. shared_outputs.add(out_edge) elif self.is_shared_data(intermediate_node, sdfg): - # The intermediate data is refered to somewhere else. - # So it must be passed. + # The intermediate data is used somewhere else, either in this or another state. shared_outputs.add(out_edge) else: # The intermediate can be removed, as it is not used anywhere else. @@ -741,253 +742,146 @@ def partition_first_outputs( return (pure_outputs, exclusive_outputs, shared_outputs) -def all_nodes_between( - graph: Union[dace.SDFG, dace.SDFGState], - begin: nodes.Node, - end: nodes.Node, - reverse: bool = False, -) -> Union[Set[nodes.Node], None]: - """Find all nodes that are reachable from `begin` but bound by `end`. - - Essentially the function starts a DFS at `begin`. If an edge is found that lead - to `end`, this edge is ignored. It will thus found any node that is reachable - from `begin` by a path that does not involve `end`. The returned set will - never contain `end` nor `begin`. In case `end` is never found the function - will return `None`. - - If `reverse` is set to `True` the function will start exploring at `end` and - follows the outgoing edges, i.e. the meaning of `end` and `begin` are swapped. + def all_nodes_between( + self, + graph: Union[dace.SDFG, dace.SDFGState], + begin: nodes.Node, + end: nodes.Node, + reverse: bool = False, + ) -> Union[Set[nodes.Node], None]: + """Find all nodes that are reachable from `begin` but bound by `end`. + + Essentially the function starts a DFS at `begin`. If an edge is found that lead + to `end`, this edge is ignored. It will thus found any node that is reachable + from `begin` by a path that does not involve `end`. The returned set will + never contain `end` nor `begin`. In case `end` is never found the function + will return `None`. + + If `reverse` is set to `True` the function will start exploring at `end` and + follows the outgoing edges, i.e. the meaning of `end` and `begin` are swapped. - Args: - graph: The graph to operate on. - begin: The start of the DFS. - end: The terminator node of the DFS. - reverse: Perform a backward DFS. - - Notes: - - The returned set will also contain the nodes of path that starts at - `begin` and ends at a node that is not `end`. - """ + Args: + graph: The graph to operate on. + begin: The start of the DFS. + end: The terminator node of the DFS. + reverse: Perform a backward DFS. + + Notes: + - The returned set will also contain the nodes of path that starts at + `begin` and ends at a node that is not `end`. + """ - def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: if reverse: - return (edge.src for edge in graph.in_edges(node)) - return (edge.dst for edge in graph.out_edges(node)) - - if reverse: - begin, end = end, begin - - to_visit: List[nodes.Node] = [begin] - seen: Set[nodes.Node] = set() - found_end: bool = False - - while len(to_visit) > 0: - n: nodes.Node = to_visit.pop() - if n == end: - found_end = True - continue - elif n in seen: - continue - seen.add(n) - to_visit.extend(next_nodes(n)) - - if not found_end: - return None + def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: + return (edge.src for edge in graph.in_edges(node)) + begin, end = end, begin + else: + def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: + return (edge.dst for edge in graph.out_edges(node)) - seen.discard(begin) - return seen + to_visit: List[dace_nodes.Node] = [begin] + seen: Set[dace_nodes.Node] = set() + while len(to_visit) > 0: + node: dace_nodes.Node = to_visit.pop() + if node != end and node not in seen: + to_visit.extend(next_nodes(node)) + seen.add(node) -def is_parallel( - graph: Union[dace.SDFG, dace.SDFGState], - node1: nodes.Node, - node2: nodes.Node, -) -> bool: - """Tests if `node1` and `node2` are parallel. + # If `end` was not found we have to return `None` to indicate this. + # `begin` and `end` are not included in the output set. + if end not in seen: + return None + return seen - {begin, end} - The nodes are parallel if `node2` can not be reached from `node1` and vice versa. - Args: - graph: The graph to traverse. - node1: The first node to check. - node2: The second node to check. - """ + def get_access_set( + self, + scope_node: Union[nodes.MapEntry, nodes.MapExit], + state: SDFGState, + ) -> Set[nodes.AccessNode]: + """Computes the access set of a "scope node". - # The `all_nodes_between()` function traverse the graph and returns `None` if - # `end` was not found. We have to call it twice, because we do not know - # which node is upstream if they are not parallel. - if all_nodes_between(graph=graph, begin=node1, end=node2) is not None: - return False - elif all_nodes_between(graph=graph, begin=node2, end=node1) is not None: - return False - return True - - -def find_downstream_consumers( - state: dace.SDFGState, - begin: Union[nodes.Node, graph.MultiConnectorEdge[dace.Memlet]], - only_tasklets: bool = False, - reverse: bool = False, -) -> Set[Tuple[nodes.Node, graph.MultiConnectorEdge[dace.Memlet]]]: - """Find all downstream connectors of `begin`. - - A consumer, in for this function, is any node that is neither an entry nor - an exit node. The function returns a set of pairs, the first element is the - node that acts as consumer and the second is the edge that leads to it. - By setting `only_tasklets` the nodes the function finds are only Tasklets. - - To find this set the function starts a search at `begin`, however, it is also - possible to pass an edge as `begin`. - If `reverse` is `True` the function essentially finds the producers that are - upstream. + If `scope_node` is a `MapEntry` node it will operate on the set of incoming + edges and if it is an `MapExit` node on the set of outgoing edges. The + function will then determine all access nodes that have a connection through + these edges to the scope nodes (edges that does not lead to access nodes are + ignored). + The function returns a set that contains all access nodes that were found. + It is important that this set will also contain views. - Args: - state: The state in which to look for the consumers. - begin: The initial node that from which the search starts. - only_tasklets: Return only Tasklets. - reverse: Follow the reverse direction. - """ - if isinstance(begin, graph.MultiConnectorEdge): - to_visit: List[graph.MultiConnectorEdge[dace.Memlet]] = [begin] - elif reverse: - to_visit = list(state.in_edges(begin)) - else: - to_visit = list(state.out_edges(begin)) - seen: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() - found: Set[Tuple[nodes.Node, graph.MultiConnectorEdge[dace.Memlet]]] = set() - - while len(to_visit) != 0: - curr_edge: graph.MultiConnectorEdge[dace.Memlet] = to_visit.pop() - next_node: nodes.Node = curr_edge.src if reverse else curr_edge.dst - - if curr_edge in seen: - continue - seen.add(curr_edge) - - if isinstance(next_node, (nodes.MapEntry, nodes.MapExit)): - if reverse: - target_conn = curr_edge.src_conn[4:] - new_edges = state.in_edges_by_connector(curr_edge.src, "IN_" + target_conn) - else: - # In forward mode a Map entry could also mean the definition of a - # dynamic map range. - if (not curr_edge.dst_conn.startswith("IN_")) and isinstance( - next_node, nodes.MapEntry - ): - # This edge defines a dynamic map range, which is a consumer - if not only_tasklets: - found.add((next_node, curr_edge)) - continue - target_conn = curr_edge.dst_conn[3:] - new_edges = state.out_edges_by_connector(curr_edge.dst, "OUT_" + target_conn) - to_visit.extend(new_edges) - del new_edges + Args: + scope_node: The scope node that should be evaluated. + state: The state in which we operate. + """ + if isinstance(scope_node, nodes.MapEntry): + get_edges = lambda node: state.in_edges(node) + other_node = lambda e: e.src else: - if only_tasklets and (not isinstance(next_node, nodes.Tasklet)): - continue - found.add((next_node, curr_edge)) - - return found - - -def find_upstream_producers( - state: dace.SDFGState, - begin: Union[nodes.Node, graph.MultiConnectorEdge[dace.Memlet]], - only_tasklets: bool = False, -) -> Set[Tuple[nodes.Node, graph.MultiConnectorEdge[dace.Memlet]]]: - """Same as `find_downstream_consumers()` but with `reverse` set to `True`.""" - return find_downstream_consumers( - state=state, - begin=begin, - only_tasklets=only_tasklets, - reverse=True, - ) - - -def get_access_set( - scope_node: Union[nodes.MapEntry, nodes.MapExit], - state: SDFGState, -) -> Set[nodes.AccessNode]: - """Computes the access set of a "scope node". + get_edges = lambda node: state.out_edges(node) + other_node = lambda e: e.dst + access_set: Set[nodes.AccessNode] = { + node + for node in map(other_node, get_edges(scope_node)) + if isinstance(node, nodes.AccessNode) + } + # As far as I know in a valid SDFG this should not happen. + assert len(access_set) == len({node.data for node in access_set}) + return access_set - If `scope_node` is a `MapEntry` node it will operate on the set of incoming - edges and if it is an `MapExit` node on the set of outgoing edges. The - function will then determine all access nodes that have a connection through - these edges to the scope nodes (edges that does not lead to access nodes are - ignored). - The function returns a set that contains all access nodes that were found. - It is important that this set will also contain views. - Args: - scope_node: The scope node that should be evaluated. - state: The state in which we operate. - """ - if isinstance(scope_node, nodes.MapEntry): - get_edges = lambda node: state.in_edges(node) - other_node = lambda e: e.src - else: - get_edges = lambda node: state.out_edges(node) - other_node = lambda e: e.dst - access_set: Set[nodes.AccessNode] = { - node - for node in map(other_node, get_edges(scope_node)) - if isinstance(node, nodes.AccessNode) - } - # As far as I know in a valid SDFG this should not happen. - assert len(access_set) == len({node.data for node in access_set}) - return access_set - - -def is_view( - node: nodes.AccessNode, - sdfg: SDFG, -) -> bool: - """Tests if `node` points to a view or not.""" - node_desc: data.Data = node.desc(sdfg) - return isinstance(node_desc, data.View) + def is_view( + self, + node: nodes.AccessNode, + sdfg: SDFG, + ) -> bool: + """Tests if `node` points to a view or not.""" + node_desc: data.Data = node.desc(sdfg) + return isinstance(node_desc, data.View) -def track_view( - view: nodes.AccessNode, - state: SDFGState, - sdfg: SDFG, -) -> nodes.AccessNode: - """Find the original data of a View. + def track_view( + self, + view: nodes.AccessNode, + state: SDFGState, + sdfg: SDFG, + ) -> nodes.AccessNode: + """Find the original data of a View. - Given the View `view`, the function will trace the view back to the - original access node. - For convenience, if `view` is not a `View` but a normal data descriptor, - then the function will return the argument unmodified. + Given the View `view`, the function will trace the view back to the + original access node. + For convenience, if `view` is not a `View` but a normal data descriptor, + then the function will return the argument unmodified. - Args: - view: The view that should be traced. - state: The state in which we operate. - sdfg: The SDFG on which we operate. - """ + Args: + view: The view that should be traced. + state: The state in which we operate. + sdfg: The SDFG on which we operate. + """ - # Test if it is a view at all, if not return the passed node as source. - if not is_view(view, sdfg): - return view + # Test if it is a view at all, if not return the passed node as source. + if not is_view(view, sdfg): + return view - # First determine if the view is used for reading or writing. - curr_edge = dace.sdfg.utils.get_view_edge(state, view) - if curr_edge is None: - raise RuntimeError(f"Failed to determine the direction of the view '{view}'.") - if curr_edge.dst_conn == "view": - # The view is used for reading. - next_node = lambda curr_edge: curr_edge.src - elif curr_edge.src_conn == "view": - # The view is used for writing. - next_node = lambda curr_edge: curr_edge.dst - else: - raise RuntimeError("Failed to determine the direction of the view '{view}'.") - - # Now trace the view back. - org_view = view - view = next_node(curr_edge) - while is_view(view, sdfg): + # First determine if the view is used for reading or writing. curr_edge = dace.sdfg.utils.get_view_edge(state, view) if curr_edge is None: - raise RuntimeError(f"View tracing of '{org_view}' failed at note '{view}'.") + raise RuntimeError(f"Failed to determine the direction of the view '{view}'.") + if curr_edge.dst_conn == "view": + # The view is used for reading. + next_node = lambda curr_edge: curr_edge.src + elif curr_edge.src_conn == "view": + # The view is used for writing. + next_node = lambda curr_edge: curr_edge.dst + else: + raise RuntimeError("Failed to determine the direction of the view '{view}'.") + + # Now trace the view back. + org_view = view view = next_node(curr_edge) - return view + while is_view(view, sdfg): + curr_edge = dace.sdfg.utils.get_view_edge(state, view) + if curr_edge is None: + raise RuntimeError(f"View tracing of '{org_view}' failed at note '{view}'.") + view = next_node(curr_edge) + return view diff --git a/dace/transformation/dataflow/map_fusion_parallel.py b/dace/transformation/dataflow/map_fusion_parallel.py index 414a565d7e..5a0e0045c9 100644 --- a/dace/transformation/dataflow/map_fusion_parallel.py +++ b/dace/transformation/dataflow/map_fusion_parallel.py @@ -12,11 +12,11 @@ @properties.make_properties class ParallelMapFusion(map_fusion_helper.MapFusionHelper): - """The `ParallelMapFusion` transformation allows to merge two parallel maps together. + """The `ParallelMapFusion` transformation allows to merge two parallel maps. - The `SerialMapFusion` transformation is only able to handle maps that are sequential, - however, this transformation is able to fuse _any_ maps that are not sequential - and are in the same scope. + While the `SerialMapFusion` transformation fuses maps that are sequentially + connected by an intermediate node, this transformation is able to fuse any + two maps that are not sequential and in the same scope. Args: only_if_common_ancestor: Only perform fusion if both Maps share at least one @@ -26,7 +26,7 @@ class ParallelMapFusion(map_fusion_helper.MapFusionHelper): Note: This transformation only matches the entry nodes of the Map, but will also - modify the exit nodes of the Map. + modify the exit nodes of the Maps. """ map_entry1 = transformation.transformation.PatternNode(nodes.MapEntry) @@ -65,7 +65,6 @@ def can_be_applied( sdfg: dace.SDFG, permissive: bool = False, ) -> bool: - """The transformation is applicable.""" map_entry_1: nodes.MapEntry = self.map_entry1 map_entry_2: nodes.MapEntry = self.map_entry2 @@ -83,7 +82,7 @@ def can_be_applied( # Since the match expression matches any twp Maps, we have to ensure that # the maps are parallel. The `can_be_fused()` function already verified # if they are in the same scope. - if not map_fusion_helper.is_parallel(graph=graph, node1=map_entry_1, node2=map_entry_2): + if not self._is_parallel(graph=graph, node1=map_entry_1, node2=map_entry_2): return False # Test if they have they share a node as direct ancestor. @@ -96,12 +95,42 @@ def can_be_applied( return True + def _is_parallel( + self, + graph: SDFGState, + node1: nodes.Node, + node2: nodes.Node, + ) -> bool: + """Tests if `node1` and `node2` are parallel. + + The nodes are parallel if `node2` can not be reached from `node1` and vice versa. + + Args: + graph: The graph to traverse. + node1: The first node to check. + node2: The second node to check. + """ + + # In order to be parallel they must be in the same scope. + scope = graph.scope_dict() + if scope[node1] != scope[node2]: + return False + + # The `all_nodes_between()` function traverse the graph and returns `None` if + # `end` was not found. We have to call it twice, because we do not know + # which node is upstream if they are not parallel. + if self.all_nodes_between(graph=graph, begin=node1, end=node2) is not None: + return False + elif self.all_nodes_between(graph=graph, begin=node2, end=node1) is not None: + return False + return True + + def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: """Performs the Map fusing. - Essentially, the function relocate all edges from the nodes forming the second - Map to the corresponding nodes of the first Map. Afterwards the nodes of the - second Map are removed. + Essentially, the function relocate all edges from the scope nodes (`MapEntry` + and `MapExit`) of the second map to the scope nodes of the first map. """ map_entry_1: nodes.MapEntry = self.map_entry1 diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py index 8bab1b6007..4b13a36bc1 100644 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -151,7 +151,7 @@ def has_read_write_dependency( for scope_node in [map_entry_1, map_exit_1, map_entry_2, map_exit_2]: access_sets.append({ node.data: node - for node in mfh.get_access_set(scope_node, state) + for node in self.get_access_set(scope_node, state) }) read_map_1, write_map_1, read_map_2, write_map_2 = access_sets @@ -161,7 +161,7 @@ def has_read_write_dependency( resolved_sets: List[Set[str]] = [] for unresolved_set in [read_map_1, write_map_1, read_map_2, write_map_2]: resolved_sets.append({ - mfh.track_view(node).data if mfh.is_view(node, sdfg) else node.data + self.track_view(node).data if self.is_view(node, sdfg) else node.data for node in unresolved_set.values() }) real_read_map_1, real_write_map_1, real_read_map_2, real_write_map_2 = resolved_sets @@ -195,7 +195,7 @@ def has_read_write_dependency( # For simplicity we assume that the nodes used to exchange information can # not be a View. This is a simplification. - if any(mfh.is_view(exchange_node, sdfg) for exchange_node in exchange_nodes.values()): + if any(self.is_view(exchange_node, sdfg) for exchange_node in exchange_nodes.values()): return True # This is the names of the node that are used as input of the first map and @@ -205,7 +205,7 @@ def has_read_write_dependency( # Because it is hard, we do not allow Views here, because we can not resolve # access sets (at least I can not). - if any(mfh.is_view(read_map_1[name], sdfg) for name in fused_inout_data_names): + if any(self.is_view(read_map_1[name], sdfg) for name in fused_inout_data_names): return True # This is a case that can not be handled, the above code should filter this @@ -254,7 +254,7 @@ def _check_read_write_dependency_exchange_nodes( state: SDFGState, sdfg: SDFG, ) -> bool: - """Checks if there are any rw dependencies in the exchange set. + """Checks if there are any read after write dependencies in the exchange set. Args: map_exit_1: Exit node of the first (top) map; defines writes. From 885802192691c9115171ca97c7df8b1171dbbae1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 5 Sep 2024 08:14:08 +0200 Subject: [PATCH 048/115] Refined my fix in the `_read_and_write_set()` function. I realized that an undefined dst_subset of an incomming edge of an access node is only possible if the access node is a scalar. So I modified the function such that the corrct subset in that case is generated. This case happens if a single entry of an array is read and written into a scalar. Then only the source subset must be defined. --- dace/sdfg/state.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 08d5229c7d..6daf2aba62 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -762,10 +762,15 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, # Filter out memlets which go out but the same data is written to the AccessNode by another memlet for out_edge in list(out_edges): for in_edge in list(in_edges): - if (in_edge.data.data == out_edge.data.data - and (in_edge.data.dst_subset is not None and in_edge.data.dst_subset.covers(out_edge.data.src_subset))): - out_edges.remove(out_edge) - break + if in_edge.data.data == out_edge.data.data: + if in_edge.data.dst_subset is None: + assert isinstance(n.desc(self.sdfg), dt.Scalar) or n.desc(self.sdfg).total_size == 1 + dst_subset = sbs.Range.from_array(n.desc(self.sdfg)) + else: + dst_subset = in_edge.data.dst_subset + if dst_subset.covers(out_edge.data.src_subset): + out_edges.remove(out_edge) + break for e in in_edges: # skip empty memlets From 09deb95916a52c5fd9151673b93a184b4b66618e Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 5 Sep 2024 09:46:19 +0200 Subject: [PATCH 049/115] Made some fixes to the _read_and_write_sets function. It now uses the proper subsets. However, it is not as efficent as possible, as it does some work multiple time. --- dace/sdfg/state.py | 57 +++++++++++++++++++++++----------------------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 6daf2aba62..3de8277074 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -755,34 +755,35 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, ws = collections.defaultdict(list) # Traverse in topological order, so data that is written before it # is read is not counted in the read set + # TODO: This only works if every data descriptor is only once in a path. for n in utils.dfs_topological_sort(sg, sources=sg.source_nodes()): - if isinstance(n, nd.AccessNode): - in_edges = sg.in_edges(n) - out_edges = sg.out_edges(n) - # Filter out memlets which go out but the same data is written to the AccessNode by another memlet - for out_edge in list(out_edges): - for in_edge in list(in_edges): - if in_edge.data.data == out_edge.data.data: - if in_edge.data.dst_subset is None: - assert isinstance(n.desc(self.sdfg), dt.Scalar) or n.desc(self.sdfg).total_size == 1 - dst_subset = sbs.Range.from_array(n.desc(self.sdfg)) - else: - dst_subset = in_edge.data.dst_subset - if dst_subset.covers(out_edge.data.src_subset): - out_edges.remove(out_edge) - break - - for e in in_edges: - # skip empty memlets - if e.data.is_empty(): - continue - # Store all subsets that have been written - ws[n.data].append(e.data.subset) - for e in out_edges: - # skip empty memlets - if e.data.is_empty(): - continue - rs[n.data].append(e.data.subset) + if not isinstance(n, nd.AccessNode): + continue + ac_desc = n.desc(self.sdfg) + in_edges = [in_edge for in_edge in sg.in_edges(n) if not in_edge.data.is_empty()] + out_edges = [out_edge for out_edge in sg.out_edges(n) if not out_edge.data.is_empty()] + + # Filter out memlets which go out but the same data is written to the AccessNode by another memlet + for out_edge in list(out_edges): + assert (out_edge.data.src_subset is not None) or (isinstance(ac_desc, dt.Scalar) or ac_desc.total_size == 1) + src_subset = ( + sbs.Range.from_array(ac_desc) + if out_edge.data.src_subset is None + else out_edge.data.src_subset + ) + for in_edge in in_edges: + assert in_edge.data.dst_subset is not None or (isinstance(ac_desc, dt.Scalar) or ac_desc.total_size == 1) + dst_subset = ( + sbs.Range.from_array(ac_desc) + if in_edge.data.dst_subset is None + else in_edge.data.dst_subset + ) + if dst_subset.covers(src_subset): + out_edges.remove(out_edge) + break + ws[n.data].extend(sbs.Range.from_array(ac_desc) if e.data.dst_subset is None else e.data.dst_subset for e in in_edges) + rs[n.data].extend(sbs.Range.from_array(ac_desc) if e.data.src_subset is None else e.data.src_subset for e in out_edges) + # Union all subgraphs, so an array that was excluded from the read # set because it was written first is still included if it is read # in another subgraph @@ -790,7 +791,7 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, read_set[data] += accesses for data, accesses in ws.items(): write_set[data] += accesses - return read_set, write_set + return copy.deepcopy((read_set, write_set)) def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: """ From 1e14f262c4bf55c81ffc6ae6d14847978c749c86 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 5 Sep 2024 09:59:49 +0200 Subject: [PATCH 050/115] Updated the _read_and_write_sets() function. It is now a bit more stream lined. --- dace/sdfg/state.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 3de8277074..db14c0cf4e 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -760,29 +760,36 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, if not isinstance(n, nd.AccessNode): continue ac_desc = n.desc(self.sdfg) + ac_size = ac_desc.total_size in_edges = [in_edge for in_edge in sg.in_edges(n) if not in_edge.data.is_empty()] out_edges = [out_edge for out_edge in sg.out_edges(n) if not out_edge.data.is_empty()] - # Filter out memlets which go out but the same data is written to the AccessNode by another memlet + # In some conditions subsets can be `None`, we will now clean them. + in_subsets = dict() + for in_edge in in_edges: + assert in_edge.data.dst_subset is not None or (isinstance(ac_desc, dt.Scalar) or ac_size == 1) + in_subsets[in_edge] = ( + sbs.Range.from_array(ac_desc) + if in_edge.data.dst_subset is None + else in_edge.data.dst_subset + ) + out_subsets = dict() for out_edge in list(out_edges): - assert (out_edge.data.src_subset is not None) or (isinstance(ac_desc, dt.Scalar) or ac_desc.total_size == 1) - src_subset = ( + assert (out_edge.data.src_subset is not None) or (isinstance(ac_desc, dt.Scalar) or ac_size == 1) + out_subsets[out_edge] = ( sbs.Range.from_array(ac_desc) if out_edge.data.src_subset is None else out_edge.data.src_subset ) + + # Filter out memlets which go out but the same data is written to the AccessNode by another memlet + for out_edge in list(out_edges): for in_edge in in_edges: - assert in_edge.data.dst_subset is not None or (isinstance(ac_desc, dt.Scalar) or ac_desc.total_size == 1) - dst_subset = ( - sbs.Range.from_array(ac_desc) - if in_edge.data.dst_subset is None - else in_edge.data.dst_subset - ) - if dst_subset.covers(src_subset): + if in_subsets[in_edge].covers(out_subsets[out_edge]): out_edges.remove(out_edge) break - ws[n.data].extend(sbs.Range.from_array(ac_desc) if e.data.dst_subset is None else e.data.dst_subset for e in in_edges) - rs[n.data].extend(sbs.Range.from_array(ac_desc) if e.data.src_subset is None else e.data.src_subset for e in out_edges) + ws[n.data].extend(in_subsets.values()) + rs[n.data].extend(out_subsets[out_edge] for out_edge in out_edges) # Union all subgraphs, so an array that was excluded from the read # set because it was written first is still included if it is read From 92c7097610a4f7bea1dffed0b0c2f19c8c70395d Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 5 Sep 2024 11:01:53 +0200 Subject: [PATCH 051/115] Added a new test for the SDFG fusion. --- tests/transformations/mapfusion_test.py | 69 +++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index d2b12347ff..071a427228 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -173,6 +173,52 @@ def fusion_indirect_access(A: dace.float32[100], B: dace.float32[100], idx: dace out[:] = tmp[idx] +def make_interstate_transient_fusion_sdfg(): + sdfg = dace.SDFG("interstate_transient_fusion") + state1 = sdfg.add_state("state1", is_start_block=True) + state2 = sdfg.add_state_after(state1, "state2") + + for name in ["A", "B", "C", "D"]: + sdfg.add_array(name, shape=(20, 20), dtype=dace.float64, transient=False) + sdfg.arrays["B"].transient = True + + A1, B1, C1 = (state1.add_access(name) for name in ["A", "B", "C"]) + state1.add_mapped_tasklet( + "map_1_1", + map_ranges={"__i0": "0:20", "__i1": "0:20"}, + inputs={"__in1": dace.Memlet("A[__i0, __i1]")}, + code="__out = __in1 + 20", + outputs={"__out": dace.Memlet("B[__i0, __i1]")}, + input_nodes={"A": A1}, + output_nodes={"B": B1}, + external_edges=True, + ) + state1.add_mapped_tasklet( + "map_2_1", + map_ranges={"__i0": "0:20", "__i1": "0:20"}, + inputs={"__in1": dace.Memlet("B[__i0, __i1]")}, + code="__out = __in1 + 10", + outputs={"__out": dace.Memlet("C[__i0, __i1]")}, + input_nodes={"B": B1}, + output_nodes={"C": C1}, + external_edges=True, + ) + + B2, D2 = (state2.add_access(name) for name in ["B", "D"]) + state2.add_mapped_tasklet( + "map_1_2", + map_ranges={"__i0": "0:20", "__i1": "0:20"}, + inputs={"__in1": dace.Memlet("B[__i0, __i1]")}, + code="__out = __in1 + 6", + outputs={"__out": dace.Memlet("D[__i0, __i1]")}, + input_nodes={"B": B2}, + output_nodes={"D": D2}, + external_edges=True, + ) + + return sdfg, state1, state2 + + def test_fusion_simple(): sdfg = fusion_simple.to_sdfg(simplify=True) sdfg = apply_fusion(sdfg, final_maps=1) @@ -438,6 +484,28 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 assert isinstance(src, dace.nodes.AccessNode) +def test_interstate_fusion(): + """Transient between two maps is used in another state and must become shared. + """ + sdfg, state1, state2 = make_interstate_transient_fusion_sdfg() + + A = np.random.rand(20, 20) + C = np.random.rand(20, 20) + D = np.random.rand(20, 20) + + ref_C = A + 30 + ref_D = A + 26 + + assert sdfg.apply_transformations_repeated(SerialMapFusion, validate=True, validate_all=True) == 1 + assert sdfg.number_of_nodes() == 2 + assert len([node for node in state1.data_nodes() if node.data == "B"]) == 1 + + sdfg(A=A, C=C, D=D) + + assert np.allclose(C, ref_C) + assert np.allclose(D, ref_D) + + def test_parallel_fusion_simple(): N1, N2 = 10, 20 @@ -523,6 +591,7 @@ def _make_sdfg(): test_fusion_with_inverted_indices() test_fusion_with_empty_memlet() test_fusion_with_nested_sdfg_0() + test_interstate_fusion() test_fusion_with_nested_sdfg_1() test_parallel_fusion_simple() print("SUCCESS") From 041b3da6982345e5266feadcc1b2ced33e4e9f77 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 5 Sep 2024 11:10:34 +0200 Subject: [PATCH 052/115] Fixed how the interstate stuff works. --- dace/transformation/dataflow/map_fusion_helper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index a9247bf591..d358aeb8ff 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -409,6 +409,7 @@ def _compute_shared_data( # - The access node has output degree larger than 1 (input degrees larger # than one, will always be partitioned as shared anyway). prevously_seen_data: Set[str] = set() + interstate_read_symbols: Set[str] = set() for state in sdfg.nodes(): for access_node in state.data_nodes(): @@ -446,9 +447,8 @@ def _compute_shared_data( prevously_seen_data.add(access_node.data) # Now we are collecting all symbols that interstate edges read from. - interstate_read_symbols: Set[str] = set() for edge in sdfg.edges(): - interstate_read_symbols.update(edge.read_symbols()) + interstate_read_symbols.update(edge.data.read_symbols()) # We also have to keep everything the edges referrers to and is an array. shared_data.update(interstate_read_symbols.intersection(prevously_seen_data)) From 5556bd3be3b9f42ca2e48e025efeb91580161739 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 5 Sep 2024 11:13:32 +0200 Subject: [PATCH 053/115] Fixed an error. --- dace/transformation/dataflow/map_fusion_helper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index d358aeb8ff..5cfe0737a1 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -860,7 +860,7 @@ def track_view( """ # Test if it is a view at all, if not return the passed node as source. - if not is_view(view, sdfg): + if not self.is_view(view, sdfg): return view # First determine if the view is used for reading or writing. @@ -874,7 +874,7 @@ def track_view( # The view is used for writing. next_node = lambda curr_edge: curr_edge.dst else: - raise RuntimeError("Failed to determine the direction of the view '{view}'.") + raise RuntimeError(f"Failed to determine the direction of the view '{view}'.") # Now trace the view back. org_view = view From e3461c3b6ceb8bff20c39d7f1aa641d70638215d Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 5 Sep 2024 11:19:41 +0200 Subject: [PATCH 054/115] How did that thing even run. --- dace/transformation/dataflow/map_fusion_helper.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index 5cfe0737a1..d52e8d4a90 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -867,19 +867,19 @@ def track_view( curr_edge = dace.sdfg.utils.get_view_edge(state, view) if curr_edge is None: raise RuntimeError(f"Failed to determine the direction of the view '{view}'.") - if curr_edge.dst_conn == "view": + if curr_edge.dst_conn == "views": # The view is used for reading. next_node = lambda curr_edge: curr_edge.src - elif curr_edge.src_conn == "view": + elif curr_edge.src_conn == "views": # The view is used for writing. next_node = lambda curr_edge: curr_edge.dst else: - raise RuntimeError(f"Failed to determine the direction of the view '{view}'.") + raise RuntimeError(f"Failed to determine the direction of the view '{view}' | {curr_edge}.") # Now trace the view back. org_view = view view = next_node(curr_edge) - while is_view(view, sdfg): + while self.is_view(view, sdfg): curr_edge = dace.sdfg.utils.get_view_edge(state, view) if curr_edge is None: raise RuntimeError(f"View tracing of '{org_view}' failed at note '{view}'.") From 6f5edc56b9729d178a8c8ed3fa0d827cfa030582 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 5 Sep 2024 11:44:51 +0200 Subject: [PATCH 055/115] Changed the default value of the `strict_dataflow` flag (the compability flag in `MapFusion`). Now the compability mode is disabled by default, however, in auto optimizer this value is set explicitly to True. This means map fusion will use the compability mode. --- dace/transformation/dataflow/map_fusion_helper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index d52e8d4a90..59a4b4040c 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -21,7 +21,7 @@ class MapFusionHelper(transformation.SingleStateTransformation): Args: only_inner_maps: Only match Maps that are internal, i.e. inside another Map. only_toplevel_maps: Only consider Maps that are at the top. - strict_dataflow: If `True`, the default, the transformation ensures a more + strict_dataflow: If `True`, the transformation ensures a more stricter version of the data flow. Note: @@ -56,7 +56,7 @@ class MapFusionHelper(transformation.SingleStateTransformation): ) strict_dataflow = properties.Property( dtype=bool, - default=True, + default=False, desc="If `True` then the transformation will ensure a more stricter data flow.", ) From 35a5426930fe221a5657aedd5e6fe7b74c641145 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 5 Sep 2024 13:01:23 +0200 Subject: [PATCH 056/115] Fixed the _read_and_write_sets(). If there are no edges we should not add anything. The old for loop had this property. --- dace/sdfg/state.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index db14c0cf4e..454be2de09 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -788,8 +788,10 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, if in_subsets[in_edge].covers(out_subsets[out_edge]): out_edges.remove(out_edge) break + ws[n.data].extend(in_subsets.values()) - rs[n.data].extend(out_subsets[out_edge] for out_edge in out_edges) + if out_edges: + rs[n.data].extend(out_subsets[out_edge] for out_edge in out_edges) # Union all subgraphs, so an array that was excluded from the read # set because it was written first is still included if it is read From 20b728e9f4c552550d7060390ff121493ee43a7f Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 5 Sep 2024 13:46:50 +0200 Subject: [PATCH 057/115] Updated `_read_and_write_sets()`. It seems that some cache is not updated correctly, this makes the `{src,dst}_subset` attributes of the Memlet wrong. Since I have no idea where they are we now initialize the edges. I am not saying that this is good, I just say that it works. --- dace/sdfg/state.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 454be2de09..a255d3d669 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -746,6 +746,12 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, Determines what data is read and written in this subgraph, returning dictionaries from data containers to all subsets that are read/written. """ + + # Ensures that the `{src,dst}_subset` are properly set. + # TODO: find where the problems are + for edge in self.edges(): + edge.data.try_initialize(self.sdfg, self, edge) + read_set = collections.defaultdict(list) write_set = collections.defaultdict(list) from dace.sdfg import utils # Avoid cyclic import From 5dd9c6b84d995e71d9a326b0f4f16a295f9e397e Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 5 Sep 2024 13:48:49 +0200 Subject: [PATCH 058/115] Refined the test in the memlet extension of _read_and_write_sets. --- dace/sdfg/state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index a255d3d669..1b0822c572 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -773,7 +773,7 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, # In some conditions subsets can be `None`, we will now clean them. in_subsets = dict() for in_edge in in_edges: - assert in_edge.data.dst_subset is not None or (isinstance(ac_desc, dt.Scalar) or ac_size == 1) + assert in_edge.data.dst_subset is not None or (in_edge.data.dynamic or in_edge.data.volume <= 0 or in_edge.data.num_elements() == ac_size) in_subsets[in_edge] = ( sbs.Range.from_array(ac_desc) if in_edge.data.dst_subset is None @@ -781,7 +781,7 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, ) out_subsets = dict() for out_edge in list(out_edges): - assert (out_edge.data.src_subset is not None) or (isinstance(ac_desc, dt.Scalar) or ac_size == 1) + assert (out_edge.data.src_subset is not None) or (out_edge.data.dynamic or out_edge.data.volume <= 0 or out_edge.data.num_elements() == ac_size) out_subsets[out_edge] = ( sbs.Range.from_array(ac_desc) if out_edge.data.src_subset is None From 617fb8f923b69a883781eb1736acedc36efb9678 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 5 Sep 2024 13:55:10 +0200 Subject: [PATCH 059/115] Now the `transformations/move_loop_into_map_test.py::MoveLoopIntoMapTest::test_more_than_a_map` test is able to apply the transformation. I looked very hard at the SDFG any I am sure that it is possible. Also the description says that it is only potentially a dependency. I am not sure if I have made an error, since it does not feel right to change tests. --- .../move_loop_into_map_test.py | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/transformations/move_loop_into_map_test.py b/tests/transformations/move_loop_into_map_test.py index dca775bb7a..e9391ef8ae 100644 --- a/tests/transformations/move_loop_into_map_test.py +++ b/tests/transformations/move_loop_into_map_test.py @@ -1,5 +1,6 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. import dace +import copy from dace.transformation.interstate import MoveLoopIntoMap import unittest import numpy as np @@ -170,8 +171,26 @@ def test_more_than_a_map(self): body.add_nedge(aread, oread, dace.Memlet.from_array('A', aarr)) body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1') + + org_data = { + "A": np.random.rand(3, 3), + "B": np.random.rand(3, 3), + "out": np.random.rand(3, 3), + } + + unopt_data = copy.deepcopy(org_data) + sdfg(**unopt_data) + + sdfg.view() count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertFalse(count > 0) + sdfg.view() + opt_data = copy.deepcopy(org_data) + sdfg(**opt_data) + + for name in org_data.keys(): + self.assertTrue(np.allclose(opt_data[name], unopt_data[name])) + self.assertTrue(count > 0) + def test_more_than_a_map_1(self): """ From 04cadbea6e45c46e0ae369aedd9d1ec6c9e78c7c Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 5 Sep 2024 15:10:04 +0200 Subject: [PATCH 060/115] If there are no in edges, then we should not add it. --- dace/sdfg/state.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 1b0822c572..3627557a4f 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -795,7 +795,8 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, out_edges.remove(out_edge) break - ws[n.data].extend(in_subsets.values()) + if in_edges: + ws[n.data].extend(in_subsets.values()) if out_edges: rs[n.data].extend(out_subsets[out_edge] for out_edge in out_edges) From d8b3547836dbeab1b489f7bbf88e031ba880c9f3 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 5 Sep 2024 15:11:37 +0200 Subject: [PATCH 061/115] Removed some stray view calls. --- tests/transformations/move_loop_into_map_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/transformations/move_loop_into_map_test.py b/tests/transformations/move_loop_into_map_test.py index e9391ef8ae..3a4140a03f 100644 --- a/tests/transformations/move_loop_into_map_test.py +++ b/tests/transformations/move_loop_into_map_test.py @@ -181,9 +181,7 @@ def test_more_than_a_map(self): unopt_data = copy.deepcopy(org_data) sdfg(**unopt_data) - sdfg.view() count = sdfg.apply_transformations(MoveLoopIntoMap) - sdfg.view() opt_data = copy.deepcopy(org_data) sdfg(**opt_data) From 8fed4fe7c6d83d360ff90ee90e9b820293d888da Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 6 Sep 2024 08:05:51 +0200 Subject: [PATCH 062/115] Updated some checks in the _read_and_write_sets function. --- dace/sdfg/state.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 3627557a4f..0c912d4d10 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -773,15 +773,15 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, # In some conditions subsets can be `None`, we will now clean them. in_subsets = dict() for in_edge in in_edges: - assert in_edge.data.dst_subset is not None or (in_edge.data.dynamic or in_edge.data.volume <= 0 or in_edge.data.num_elements() == ac_size) + assert in_edge.data.dst_subset is not None or (in_edge.data.num_elements() == ac_size) in_subsets[in_edge] = ( sbs.Range.from_array(ac_desc) if in_edge.data.dst_subset is None else in_edge.data.dst_subset ) out_subsets = dict() - for out_edge in list(out_edges): - assert (out_edge.data.src_subset is not None) or (out_edge.data.dynamic or out_edge.data.volume <= 0 or out_edge.data.num_elements() == ac_size) + for out_edge in out_edges: + assert (out_edge.data.src_subset is not None) or (out_edge.data.num_elements() == ac_size) out_subsets[out_edge] = ( sbs.Range.from_array(ac_desc) if out_edge.data.src_subset is None From 567f459eea5adf3ee0d3c803880b2e1e4874b097 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 6 Sep 2024 08:06:37 +0200 Subject: [PATCH 063/115] Fixed the bug in `_read_and_write_sets()` that made `tests/numpy/ufunc_support_test.py::test_ufunc_add_accumulate_simple` fail. However, I am pretty sure that this is a compability bug. The function essentially removes accesses from the read set if there is a matching access in the read set. There is nothing wrong with that, however the original code only filtered that out when the data of the in- and output memlets refered to the same data. However, this is wrong, because there is no guarantee that this holds as the memlet's data member must not coincide with the node under consideration, but can be another (that is connected). But if thsi check is removed then `tests/numpy/ufunc_support_test.py::test_ufunc_add_accumulate_simple` will fail. So I put it back. --- dace/sdfg/state.py | 9 +++++++++ .../move_loop_into_map_test.py | 19 +------------------ 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 0c912d4d10..a061d0b63e 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -791,6 +791,15 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, # Filter out memlets which go out but the same data is written to the AccessNode by another memlet for out_edge in list(out_edges): for in_edge in in_edges: + if out_edge.data.data != in_edge.data.data: + # NOTE: This check does not make any sense, and is in my view wrong. + # If we consider a memlet between two access nodes, to which access + # node the `data` attribute of the memlet refers to is arbitrary and + # does not matter. However, the test will filter _some_ out but not + # all. + # This check is is retained for compatibility with `RefineNestedAccess`, + # see `tests/numpy/ufunc_support_test.py::test_ufunc_add_accumulate_simple`. + continue if in_subsets[in_edge].covers(out_subsets[out_edge]): out_edges.remove(out_edge) break diff --git a/tests/transformations/move_loop_into_map_test.py b/tests/transformations/move_loop_into_map_test.py index 3a4140a03f..dca775bb7a 100644 --- a/tests/transformations/move_loop_into_map_test.py +++ b/tests/transformations/move_loop_into_map_test.py @@ -1,6 +1,5 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. import dace -import copy from dace.transformation.interstate import MoveLoopIntoMap import unittest import numpy as np @@ -171,24 +170,8 @@ def test_more_than_a_map(self): body.add_nedge(aread, oread, dace.Memlet.from_array('A', aarr)) body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1') - - org_data = { - "A": np.random.rand(3, 3), - "B": np.random.rand(3, 3), - "out": np.random.rand(3, 3), - } - - unopt_data = copy.deepcopy(org_data) - sdfg(**unopt_data) - count = sdfg.apply_transformations(MoveLoopIntoMap) - opt_data = copy.deepcopy(org_data) - sdfg(**opt_data) - - for name in org_data.keys(): - self.assertTrue(np.allclose(opt_data[name], unopt_data[name])) - self.assertTrue(count > 0) - + self.assertFalse(count > 0) def test_more_than_a_map_1(self): """ From 9cad08dae2a0d7a2680a2bb8fa219a10f594b26c Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 6 Sep 2024 09:03:21 +0200 Subject: [PATCH 064/115] Added a test to back my claims I did in `567f459e`. The test obey the bug that is mentioned. I also tested the old version, it has the bug (and some more). --- dace/sdfg/state.py | 3 +- tests/sdfg/state_test.py | 91 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index a061d0b63e..23ac2f763d 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -796,7 +796,8 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, # If we consider a memlet between two access nodes, to which access # node the `data` attribute of the memlet refers to is arbitrary and # does not matter. However, the test will filter _some_ out but not - # all. + # all. See also the tests inside `tests/sdfg/state_test.py` for the + # wrong behaviour this check induces. # This check is is retained for compatibility with `RefineNestedAccess`, # see `tests/numpy/ufunc_support_test.py::test_ufunc_add_accumulate_simple`. continue diff --git a/tests/sdfg/state_test.py b/tests/sdfg/state_test.py index eb4e97ba66..198ccc4ecf 100644 --- a/tests/sdfg/state_test.py +++ b/tests/sdfg/state_test.py @@ -1,5 +1,6 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. import dace +from dace import subsets as sbs from dace.transformation.helpers import find_sdfg_control_flow @@ -43,6 +44,7 @@ def test_read_write_set_y_formation(): assert 'B' not in state.read_and_write_sets()[0] + def test_deepcopy_state(): N = dace.symbol('N') @@ -58,7 +60,96 @@ def double_loop(arr: dace.float32[N]): sdfg.validate() +def test_read_and_write_set_filter(): + sdfg = dace.SDFG('graph') + state = sdfg.add_state('state') + sdfg.add_array('A', [2, 2], dace.float64) + sdfg.add_scalar('B', dace.float64) + sdfg.add_array('C', [2, 2], dace.float64) + A, B, C = (state.add_access(name) for name in ('A', 'B', 'C')) + + state.add_nedge( + A, + B, + dace.Memlet("B[0] -> 0, 0"), + ) + state.add_nedge( + B, + C, + # If the Memlet would be `B[0] -> 1, 1` it would then be filtered out. + # This is an intentional behaviour for compatibility. + dace.Memlet("C[1, 1] -> 0"), + ) + state.add_nedge( + B, + C, + dace.Memlet("B[0] -> 0, 0"), + ) + sdfg.validate() + + expected_reads = { + "A": [sbs.Range.from_string("0, 0")], + # See comment in `state._read_and_write_sets()` why "B" is here + # it should actually not, but it is a bug. + "B": [sbs.Range.from_string("0")], + } + expected_writes = { + # However, this should always be here. + "B": [sbs.Range.from_string("0")], + "C": [sbs.Range.from_string("0, 0"), sbs.Range.from_string("1, 1")], + } + read_set, write_set = state._read_and_write_sets() + + for expected_sets, computed_sets in [(expected_reads, read_set), (expected_writes, write_set)]: + assert expected_sets.keys() == computed_sets.keys(), f"Expected the set to contain '{expected_sets.keys()}' but got '{computed_sets.keys()}'." + for access_data in expected_sets.keys(): + for exp in expected_sets[access_data]: + found_match = False + for res in computed_sets[access_data]: + if res == exp: + found_match = True + break + assert found_match, f"Could not find the subset '{exp}' only got '{computed_sets}'" + + +def test_read_and_write_set_selection(): + sdfg = dace.SDFG('graph') + state = sdfg.add_state('state') + sdfg.add_array('A', [2, 2], dace.float64) + sdfg.add_scalar('B', dace.float64) + A, B = (state.add_access(name) for name in ('A', 'B')) + + state.add_nedge( + A, + B, + dace.Memlet("A[0, 0]"), + ) + sdfg.validate() + + expected_reads = { + "A": [sbs.Range.from_string("0, 0")], + } + expected_writes = { + "B": [sbs.Range.from_string("0")], + } + read_set, write_set = state._read_and_write_sets() + + for expected_sets, computed_sets in [(expected_reads, read_set), (expected_writes, write_set)]: + assert expected_sets.keys() == computed_sets.keys(), f"Expected the set to contain '{expected_sets.keys()}' but got '{computed_sets.keys()}'." + for access_data in expected_sets.keys(): + for exp in expected_sets[access_data]: + found_match = False + for res in computed_sets[access_data]: + if res == exp: + found_match = True + break + assert found_match, f"Could not find the subset '{exp}' only got '{computed_sets}'" + + if __name__ == '__main__': test_read_write_set() test_read_write_set_y_formation() test_deepcopy_state() + test_read_and_write_set_selection() + test_read_and_write_set_filter() + From 866d815f4c58e4ef4b9dcc15bbedc1ab02f93963 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 6 Sep 2024 09:46:08 +0200 Subject: [PATCH 065/115] Added a new test for map fusion that tests indirect accesses. --- tests/transformations/mapfusion_test.py | 50 +++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 071a427228..0f444f5382 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -579,6 +579,55 @@ def _make_sdfg(): assert np.allclose(out2, res2) +def test_fuse_indirect_accesses(): + + @dace.program(auto_optimize=False) + def inner_product( + A: dace.float32[20], + B: dace.float32[20], + idx: dace.int32[20], + out: dace.float32[20], + ): + tmp1 = np.empty_like(A) + tmp2 = np.empty_like(A) + for i in dace.map[0:20]: + tmp1[i] = A[i] * B[i] + for i in dace.map[0:20]: + tmp2[i] = tmp1[i] + A[i] + for i in dace.map[0:20]: + with dace.tasklet: + __arr << tmp2(1)[:] + __idx << idx[i] + __out >> out[i] + __out = __arr[__idx] + + sdfg = inner_product.to_sdfg(simplify=True) + assert sdfg.number_of_nodes() == 1 + assert count_node(sdfg, nodes.MapEntry) == 3 + + apply_fusion(sdfg, final_maps=2) + + # The last map, with the indirect access, can not be fused, so check that. + state = next(iter(sdfg.nodes())) + assert len(list(state.sink_nodes())) == 1 + out_node = next(iter(state.sink_nodes())) + assert out_node.data == "out" + assert state.in_degree(out_node) == 1 + + # Now find the last map and the indirect access Tasklet + last_map_exit = next(iter(state.in_edges(out_node))).src + last_map_entry = state.entry_node(last_map_exit) + assert isinstance(last_map_exit, nodes.MapExit) + assert state.in_degree(last_map_exit) == 1 + + indirect_access_tasklet = next(iter(state.in_edges(last_map_exit))).src + assert isinstance(indirect_access_tasklet, nodes.Tasklet) + assert indirect_access_tasklet.code == "__out = __arr[__idx]" # TODO: Regex with connectors + + # The tasklet can only be connected to a map entry. + assert all(in_edge.src is last_map_entry for in_edge in state.in_edges(indirect_access_tasklet)) + + if __name__ == '__main__': test_indirect_accesses() test_fusion_shared() @@ -594,5 +643,6 @@ def _make_sdfg(): test_interstate_fusion() test_fusion_with_nested_sdfg_1() test_parallel_fusion_simple() + test_fuse_indirect_accesses() print("SUCCESS") From a5846b5b1a0e928dabbd9ecf340c19af70e978d5 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 6 Sep 2024 14:14:43 +0200 Subject: [PATCH 066/115] Reworked the serial map fusion and related files. --- dace/transformation/dataflow/buffer_tiling.py | 6 +- .../dataflow/map_fusion_helper.py | 429 +++--------- .../dataflow/map_fusion_parallel.py | 28 +- .../dataflow/map_fusion_serial.py | 618 ++++++++++-------- dace/transformation/dataflow/mapreduce.py | 5 +- tests/transformations/apply_to_test.py | 8 +- tests/transformations/mapfusion_test.py | 4 +- 7 files changed, 487 insertions(+), 611 deletions(-) diff --git a/dace/transformation/dataflow/buffer_tiling.py b/dace/transformation/dataflow/buffer_tiling.py index b7d7a5607b..6fac761175 100644 --- a/dace/transformation/dataflow/buffer_tiling.py +++ b/dace/transformation/dataflow/buffer_tiling.py @@ -100,9 +100,9 @@ def apply(self, graph, sdfg): some_buffer = next(iter(buffers)) # some dummy to pass to MapFusion.apply_to() MapFusion.apply_to( sdfg, - map_exit1=tile_map1_exit, - access_node=some_buffer, - map_entry2=tile_map2_entry, + map_exit_1=tile_map1_exit, + intermediate_access_node=some_buffer, + map_entry_2=tile_map2_entry, verify=True, ) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index 59a4b4040c..4cb1f4ca4a 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -43,6 +43,11 @@ class MapFusionHelper(transformation.SingleStateTransformation): default=False, desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", ) + strict_dataflow = properties.Property( + dtype=bool, + default=False, + desc="If `True` then the transformation will ensure a more stricter data flow.", + ) shared_data = properties.DictProperty( key_type=SDFG, value_type=set, #[str] @@ -54,11 +59,6 @@ class MapFusionHelper(transformation.SingleStateTransformation): " because they transmit data _between states_, such data will be made 'shared'." " This variable acts as a cache, and is managed by 'is_shared_data()'.", ) - strict_dataflow = properties.Property( - dtype=bool, - default=False, - desc="If `True` then the transformation will ensure a more stricter data flow.", - ) def __init__( @@ -80,7 +80,7 @@ def __init__( @classmethod def expressions(cls) -> bool: - raise RuntimeError("The `_MapFusionHelper` is not a transformation on its own.") + raise RuntimeError("The `MapFusionHelper` is not a transformation on its own.") def can_be_fused( @@ -99,9 +99,6 @@ def can_be_fused( - The scheduling of the maps. - The map parameters. - However, for performance reasons, the function does not check if the node - decomposition exists. - Args: map_entry_1: The entry of the first (in serial case the top) map. map_exit_2: The entry of the second (in serial case the bottom) map. @@ -172,6 +169,7 @@ def relocate_nodes( # We now determine which edges we have to migrate, for this we are looking at # the incoming edges, because this allows us also to detect dynamic map ranges. + # TODO(phimuell): If there is already a connection to the node, reuse this. for edge_to_move in list(state.in_edges(from_node)): assert isinstance(edge_to_move.dst_conn, str) @@ -195,23 +193,21 @@ def relocate_nodes( helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) from_node.remove_in_connector(dmr_symbol) - # There is no other edge that we have to consider, so we just end here - continue - - # We have a Passthrough connection, i.e. there exists a matching `OUT_`. - old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix - new_conn = to_node.next_connector(old_conn) - - to_node.add_in_connector("IN_" + new_conn) - for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): - helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) - to_node.add_out_connector("OUT_" + new_conn) - for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): - helpers.redirect_edge( - state, e, new_src=to_node, new_src_conn="OUT_" + new_conn - ) - from_node.remove_in_connector("IN_" + old_conn) - from_node.remove_out_connector("OUT_" + old_conn) + else: + # We have a Passthrough connection, i.e. there exists a matching `OUT_`. + old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix + new_conn = to_node.next_connector(old_conn) + + to_node.add_in_connector("IN_" + new_conn) + for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): + helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) + to_node.add_out_connector("OUT_" + new_conn) + for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): + helpers.redirect_edge( + state, e, new_src=to_node, new_src_conn="OUT_" + new_conn + ) + from_node.remove_in_connector("IN_" + old_conn) + from_node.remove_out_connector("OUT_" + old_conn) # Check if we succeeded. if state.out_degree(from_node) != 0: @@ -313,7 +309,7 @@ def find_parameter_remapping( return None assert len(unused_first_params) == 0 - assert len(final_mapping) == len(first_params) + assert len(final_mapping) == len(unmapped_second_params) return final_mapping @@ -367,7 +363,9 @@ def is_shared_data( Interstate data is used to transmit data between multiple state or by extension within the state. Thus it must be classified as a shared output. This function will go through the SDFG to and collect the names of all data - container that should be classified as shared. + container that should be classified as shared. Note that this is an over + approximation as it does not take the location into account, i.e. "is no longer + used". Args: transient: The transient that should be checked. @@ -375,6 +373,8 @@ def is_shared_data( Note: The function computes the this set once for every SDFG and then caches it. + There is no mechanism to detect if the cache must be evicted. However, + as long as no additional data is added, there is no problem. """ if sdfg not in self.shared_data: self._compute_shared_data(sdfg) @@ -466,8 +466,8 @@ def _compute_multi_write_data( Essentially this function computes the set of data that does not follow the single static assignment idiom. The function also resolves views. - If an access node that refers to a view, the function will add not only - the view itself, but also the data it refers to. + If an access node, refers to a view, not only the view itself, but also + the data it refers to is added to the set. Args: state: The state that should be examined. @@ -475,12 +475,9 @@ def _compute_multi_write_data( Note: This information is used by the partition function (in case strict data - flow mode is enabled), if it is legal to turn a intermediate node into - shared output or if the partition does not exists at all. The current + flow mode is enabled), in strict data flow mode only. The current implementation is rather simple as it only checks if a data is written - to multiple times in the same state. A more refined (but still simple) - implementation would take the location of the access node into - consideration. + to multiple times in the same state. """ data_written_to: Set[str] = set() multi_write_data: Set[str] = set() @@ -488,311 +485,48 @@ def _compute_multi_write_data( for access_node in state.data_nodes(): if state.in_degree(access_node) == 0: continue - if self.is_view(access_node, sdfg): + if access_node.data in data_written_to: + multi_write_data.add(access_node.data) + elif self.is_view(access_node, sdfg): # This is an over approximation. multi_write_data.update([access_node.data, track_view(access_node, state, sdfg).data]) - elif access_node.data in data_written_to: - multi_write_data.add(access_node.data) data_written_to.add(access_node.data) return multi_write_data - def partition_first_outputs( - self, - state: SDFGState, - sdfg: SDFG, - map_exit_1: nodes.MapExit, - map_entry_2: nodes.MapEntry, - ) -> Union[ - Tuple[ - Set[graph.MultiConnectorEdge[dace.Memlet]], - Set[graph.MultiConnectorEdge[dace.Memlet]], - Set[graph.MultiConnectorEdge[dace.Memlet]], - ], - None, - ]: - """Partition the output edges of `map_exit_1` for serial map fusion. - - The output edges of the first map are partitioned into three distinct sets, - defined as follows: - - - Pure Output Set `\mathbb{P}`: - These edges exits the first map and does not enter the second map. These - outputs will be simply be moved to the output of the second map. - - Exclusive Intermediate Set `\mathbb{E}`: - Edges in this set leaves the first map exit, enters an access node, from - where a Memlet then leads immediately to the second map. The memory - referenced by this access node is not used anywhere else, thus it can - be removed. - - Shared Intermediate Set `\mathbb{S}`: - These edges are very similar to the one in `\mathbb{E}` except that they - are used somewhere else, thus they can not be removed and have to be - recreated as output of the second map. - - Returns: - If such a decomposition exists the function will return the three sets - mentioned above in the same order. - In case the decomposition does not exist, i.e. the maps can not be fused - the function returns `None`. - - Args: - state: The in which the two maps are located. - sdfg: The full SDFG in whcih we operate. - map_exit_1: The exit node of the first map. - map_entry_2: The entry node of the second map. - """ - # The three outputs set. - pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() - exclusive_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() - shared_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() - - # These are the iteration parameters of the two maps. - # They are not yet modified, that they match each other. - map_params_1: Sequence[str] = map_exit_1.map.params - map_params_2: Sequence[str] = map_entry_2.map.params - - # Compute the renaming that for translating the parameter of the _second_ - # map to the ones used by the first map. - repl_dict: Dict[str, str] = self.find_parameter_remapping( - first_map=map_exit_1.map, - second_map=map_entry_2.map, - ) - assert repl_dict is not None - - # Set of intermediate nodes that we have already processed. - processed_inter_nodes: Set[nodes.Node] = set() - - # These are the data that is written to multiple times in _this_ state. - # If a data is written to multiple time in a state, it could be - # classified as shared. However, it might happen that the node has zero - # degree. This is not a problem as the maps also induced a before-after - # relationship. But some DaCe transformations do not catch this. - # Thus we will never modify such intermediate nodes. - if self.strict_dataflow: - multi_write_data: Set[str] = self._compute_multi_write_data(state, sdfg) - else: - multi_write_data = set() - - # Now scan all output edges of the first exit and classify them - for out_edge in state.out_edges(map_exit_1): - intermediate_node: nodes.Node = out_edge.dst - - # We already processed the node, this should indicate that we should - # run simplify again, or we should start implementing this case. - # TODO(phimuell): Handle this case, it is partially implemented. - if intermediate_node in processed_inter_nodes: - return None - processed_inter_nodes.add(intermediate_node) - - # The intermediate can only have one incoming degree. It might be possible - # to handle multiple incoming edges, if they all come from the top map. - # However, the resulting SDFG might be invalid. - # NOTE: Allow this to happen (under certain cases) if the only producer - # is the top map. - if state.in_degree(intermediate_node) != 1: - return None - - # Now let's look at all nodes that are downstream of the intermediate node. - # This, among other things, will tell us, how we have to handle this node. - # NOTE: The traversal will stop at the second map. - downstream_nodes = self.all_nodes_between( - graph=state, - begin=intermediate_node, - end=map_entry_2, - ) - - # If `downstream_nodes` is `None` this means that `map_entry_2` was never - # reached, thus `intermediate_node` does not enter the second map and - # the node is a pure output node. - if downstream_nodes is None: - pure_outputs.add(out_edge) - continue - - # The following tests are _after_ we have determined if we have a pure - # output node, because this allows us to handle more exotic pure node - # cases, as handling them is essentially rerouting an edge, whereas - # handling intermediate nodes is much more complicated. - - # Checks if the intermediate node refers to data that is accessed by - # _other_ access nodes in _this_ state. If this is the case then never - # touch this intermediate node. - # TODO(phimuell): Technically it would be enough to turn the node into - # a shared output node, because this will still fulfil the dependencies. - # However, some DaCe transformation can not handle this properly, so we - # are _forced_ to reject this node. - if intermediate_node.data in multi_write_data: - return None - - # For us an intermediate node must always be an access node, because - # everything else we do not know how to handle. It is important that - # we do not test for non transient data here, because they can be - # handled has shared intermediates. - if not isinstance(intermediate_node, nodes.AccessNode): - return None - intermediate_desc: data.Data = intermediate_node.desc(sdfg) - if isinstance(intermediate_desc, data.View): - return None - - # Empty Memlets are only allowed if they are in `\mathbb{P}`, which - # is also the only place they really make sense (for a map exit). - # Thus if we now found an empty Memlet we reject it. - if out_edge.data.is_empty(): - return None - - # It can happen that multiple edges converges at the `IN_` connector - # of the first map exit, but there is only one edge leaving the exit. - # It is complicate to handle this, so for now we ignore it. - # TODO(phimuell): Handle this case properly. - # To handle this we need to associate a consumer edge (the outgoing edges - # of the second map) with exactly one producer. - producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list(state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:])) - if len(producer_edges) > 1: - return None - - # Now check the constraints we have on the producers. - # - The source of the producer can not be a view (we do not handle this) - # - The edge shall also not be a reduction edge. - # - Defined location to where they write. - # - No dynamic Melets. - # Furthermore, we will also extract the subsets, i.e. the location they - # modify inside the intermediate array. - # Since we do not allow for WCR, we do not check if the producer subsets intersects. - producer_subsets: List[subsets.Subset] = [] - for producer_edge in producer_edges: - if isinstance(producer_edge.src, nodes.AccessNode) and self.is_view(producer_edge.src, sdfg): - return None - if producer_edge.data.dynamic: - return None - if producer_edge.data.wcr is not None: - return None - if producer_edge.data.dst_subset is None: - return None - producer_subsets.append(producer_edge.data.dst_subset) - - # Now we determine the consumer of nodes. For this we are using the edges - # leaves the second map entry. It is not necessary to find the actual - # consumer nodes, as they might depend on symbols of nested Maps. - # For the covering test we only need their subsets, but we will perform - # some scan and filtering on them. - found_second_map = False - consumer_subsets: List[subsets.Subset] = [] - for intermediate_node_out_edge in state.out_edges(intermediate_node): - - # Ensure that there is no multihop connection to the second map entry. - if intermediate_node_out_edge.dst is not map_entry_2: - if self.all_nodes_between( - graph=state, - begin=intermediate_node_out_edge.dst, - end=map_entry_2, - ) is None: - continue - return None - - # Ensure that the second map is found exactly once. - if found_second_map: - # TODO(phimuell): Lift this restriction. - return None - found_second_map = True - - # Now we look at all edges that leave the second map entry, as they - # define what is read inside the map. - # NOTE: The subset still uses the old iteration variables. - assert intermediate_node_out_edge.dst_conn.startswith("IN_") - for inner_consumer_edge in state.out_edges_by_connector(map_entry_2, "OUT_" + intermediate_node_out_edge.dst_conn[3:]): - if inner_consumer_edge.data.src_subset is None: - return None - if inner_consumer_edge.data.dynamic: - # TODO(phimuell): Is this restriction necessary, I am not sure. - return None - consumer_subsets.append(inner_consumer_edge.data.src_subset) - assert found_second_map, f"Found '{intermediate_node}' which looked like a pure node, but is not one." - assert len(consumer_subsets) != 0 - - # The consumer still uses the original symbols of the second map, so we must rename them. - if repl_dict: - consumer_subsets = copy.deepcopy(consumer_subsets) - for consumer_subset in consumer_subsets: - symbolic.safe_replace(mapping=repl_dict, replace_callback=consumer_subset.replace) - - # Now we are checking if a single iteration of the first (top) map - # can satisfy all data requirements of the second (bottom) map. - # For this we look if the producer covers the consumer. A consumer must - # be covered by exactly one producer. - for consumer_subset in consumer_subsets: - nb_coverings = sum(producer_subset.covers(consumer_subset) for producer_subset in producer_subsets) - if nb_coverings != 1: - return None - - # After we have ensured coverage, we have to decide if the intermediate - # node can be removed (`\mathbb{E}`) or has to be restored (`\mathbb{S}`). - # Note that "removed" here means that it is reconstructed by a new - # output of the second map. - if len(downstream_nodes) != 0: - # The intermediate node is connected to more node inside _this_ state. - shared_outputs.add(out_edge) - elif self.is_shared_data(intermediate_node, sdfg): - # The intermediate data is used somewhere else, either in this or another state. - shared_outputs.add(out_edge) - else: - # The intermediate can be removed, as it is not used anywhere else. - exclusive_outputs.add(out_edge) - - assert exclusive_outputs or shared_outputs or pure_outputs - assert len(processed_inter_nodes) == sum(len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs]) - return (pure_outputs, exclusive_outputs, shared_outputs) - - - def all_nodes_between( + def is_node_reachable_from( self, graph: Union[dace.SDFG, dace.SDFGState], begin: nodes.Node, end: nodes.Node, - reverse: bool = False, - ) -> Union[Set[nodes.Node], None]: - """Find all nodes that are reachable from `begin` but bound by `end`. + ) -> bool: + """Test if the node `end` can be reached from `begin`. Essentially the function starts a DFS at `begin`. If an edge is found that lead - to `end`, this edge is ignored. It will thus found any node that is reachable - from `begin` by a path that does not involve `end`. The returned set will - never contain `end` nor `begin`. In case `end` is never found the function - will return `None`. - - If `reverse` is set to `True` the function will start exploring at `end` and - follows the outgoing edges, i.e. the meaning of `end` and `begin` are swapped. + to `end` the function returns `True`. If the node is never found `False` is + returned. Args: graph: The graph to operate on. begin: The start of the DFS. - end: The terminator node of the DFS. - reverse: Perform a backward DFS. - - Notes: - - The returned set will also contain the nodes of path that starts at - `begin` and ends at a node that is not `end`. + end: The node that should be located. """ - - if reverse: - def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: - return (edge.src for edge in graph.in_edges(node)) - begin, end = end, begin - else: - def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: - return (edge.dst for edge in graph.out_edges(node)) + def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: + return (edge.dst for edge in graph.out_edges(node)) to_visit: List[dace_nodes.Node] = [begin] seen: Set[dace_nodes.Node] = set() while len(to_visit) > 0: node: dace_nodes.Node = to_visit.pop() - if node != end and node not in seen: + if node == end: + return True + elif node not in seen: to_visit.extend(next_nodes(node)) seen.add(node) - # If `end` was not found we have to return `None` to indicate this. - # `begin` and `end` are not included in the output set. - if end not in seen: - return None - return seen - {begin, end} + # We never found `end` + return False def get_access_set( @@ -802,11 +536,10 @@ def get_access_set( ) -> Set[nodes.AccessNode]: """Computes the access set of a "scope node". - If `scope_node` is a `MapEntry` node it will operate on the set of incoming - edges and if it is an `MapExit` node on the set of outgoing edges. The - function will then determine all access nodes that have a connection through - these edges to the scope nodes (edges that does not lead to access nodes are - ignored). + If `scope_node` is a `MapEntry` it will operate on the set of incoming edges + and if it is an `MapExit` on the set of outgoing edges. The function will + then determine all access nodes that have a connection through these edges + to the scope nodes (edges that does not lead to access nodes are ignored). The function returns a set that contains all access nodes that were found. It is important that this set will also contain views. @@ -825,11 +558,58 @@ def get_access_set( for node in map(other_node, get_edges(scope_node)) if isinstance(node, nodes.AccessNode) } - # As far as I know in a valid SDFG this should not happen. - assert len(access_set) == len({node.data for node in access_set}) + + return access_set + def find_subsets( + self, + node: nodes.AccessNode, + scope_node: Union[nodes.MapExit, nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, + repl_dict: Optional[Dict[str, str]], + ) -> List[subsets.Subset]: + """Finds all subsets that access `node` within `scope_node`. + + The function will not start a search for all consumer/producers. + Instead it will locate the edges which is immediately inside the + map scope. + + Args: + node: The access node that should be examined. + scope_node: We are only interested in data that flows through this node. + state: The state in which we operate. + sdfg: The SDFG object. + """ + + # Is the node used for reading or for writing. + # This influences how we have to proceed. + if isinstance(scope_node, nodes.MapEntry): + outer_edges_to_inspect = [e for e in state.in_edges(scope_node) if e.src == node] + get_subset = lambda e: e.data.src_subset + get_inner_edges = lambda e: state.out_edges_by_connector(scope_node, "OUT_" + e.dst_conn[3:]) + else: + outer_edges_to_inspect = [e for e in state.out_edges(scope_node) if e.dst == node] + get_subset = lambda e: e.data.dst_subset + get_inner_edges = lambda e: state.in_edges_by_connector(scope_node, "IN_" + e.src_conn[4:]) + + found_subsets: List[subsets.Subset] = [] + for edge in outer_edges_to_inspect: + found_subsets.extend(get_subset(e) for e in get_inner_edges(edge)) + assert len(found_subsets) > 0, f"Could not find any subsets." + assert not any(subset is None for subset in found_subsets) + + found_subsets = copy.deepcopy(found_subsets) + if repl_dict: + for subset in found_subsets: + # Replace happens in place + symbolic.safe_replace(repl_dict, subset.replace) + + return found_subsets + + def is_view( self, node: nodes.AccessNode, @@ -848,10 +628,9 @@ def track_view( ) -> nodes.AccessNode: """Find the original data of a View. - Given the View `view`, the function will trace the view back to the - original access node. - For convenience, if `view` is not a `View` but a normal data descriptor, - then the function will return the argument unmodified. + Given the View `view`, the function will trace the view back to the original + access node. For convenience, if `view` is not a `View` the argument will be + returned. Args: view: The view that should be traced. diff --git a/dace/transformation/dataflow/map_fusion_parallel.py b/dace/transformation/dataflow/map_fusion_parallel.py index 5a0e0045c9..f78bcd0b76 100644 --- a/dace/transformation/dataflow/map_fusion_parallel.py +++ b/dace/transformation/dataflow/map_fusion_parallel.py @@ -15,7 +15,7 @@ class ParallelMapFusion(map_fusion_helper.MapFusionHelper): """The `ParallelMapFusion` transformation allows to merge two parallel maps. While the `SerialMapFusion` transformation fuses maps that are sequentially - connected by an intermediate node, this transformation is able to fuse any + connected through an intermediate node, this transformation is able to fuse any two maps that are not sequential and in the same scope. Args: @@ -29,8 +29,8 @@ class ParallelMapFusion(map_fusion_helper.MapFusionHelper): modify the exit nodes of the Maps. """ - map_entry1 = transformation.transformation.PatternNode(nodes.MapEntry) - map_entry2 = transformation.transformation.PatternNode(nodes.MapEntry) + map_entry_1 = transformation.transformation.PatternNode(nodes.MapEntry) + map_entry_2 = transformation.transformation.PatternNode(nodes.MapEntry) only_if_common_ancestor = properties.Property( dtype=bool, @@ -54,7 +54,7 @@ def __init__( def expressions(cls) -> Any: # This just matches _any_ two Maps inside a state. state = graph.OrderedMultiDiConnectorGraph() - state.add_nodes_from([cls.map_entry1, cls.map_entry2]) + state.add_nodes_from([cls.map_entry_1, cls.map_entry_2]) return [state] @@ -65,8 +65,12 @@ def can_be_applied( sdfg: dace.SDFG, permissive: bool = False, ) -> bool: - map_entry_1: nodes.MapEntry = self.map_entry1 - map_entry_2: nodes.MapEntry = self.map_entry2 + """Checks if the fusion can be done. + + The function checks the general fusing conditions and if the maps are parallel. + """ + map_entry_1: nodes.MapEntry = self.map_entry_1 + map_entry_2: nodes.MapEntry = self.map_entry_2 # Check the structural properties of the maps, this will also ensure that # the two maps are in the same scope and the parameters can be renamed @@ -82,7 +86,7 @@ def can_be_applied( # Since the match expression matches any twp Maps, we have to ensure that # the maps are parallel. The `can_be_fused()` function already verified # if they are in the same scope. - if not self._is_parallel(graph=graph, node1=map_entry_1, node2=map_entry_2): + if not self.is_parallel(graph=graph, node1=map_entry_1, node2=map_entry_2): return False # Test if they have they share a node as direct ancestor. @@ -95,7 +99,7 @@ def can_be_applied( return True - def _is_parallel( + def is_parallel( self, graph: SDFGState, node1: nodes.Node, @@ -119,9 +123,9 @@ def _is_parallel( # The `all_nodes_between()` function traverse the graph and returns `None` if # `end` was not found. We have to call it twice, because we do not know # which node is upstream if they are not parallel. - if self.all_nodes_between(graph=graph, begin=node1, end=node2) is not None: + if self.is_node_reachable_from(graph=graph, begin=node1, end=node2): return False - elif self.all_nodes_between(graph=graph, begin=node2, end=node1) is not None: + elif self.is_node_reachable_from(graph=graph, begin=node2, end=node1): return False return True @@ -133,9 +137,9 @@ def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: and `MapExit`) of the second map to the scope nodes of the first map. """ - map_entry_1: nodes.MapEntry = self.map_entry1 + map_entry_1: nodes.MapEntry = self.map_entry_1 map_exit_1: nodes.MapExit = graph.exit_node(map_entry_1) - map_entry_2: nodes.MapEntry = self.map_entry2 + map_entry_2: nodes.MapEntry = self.map_entry_2 map_exit_2: nodes.MapExit = graph.exit_node(map_entry_2) # Before we do anything we perform the renaming. diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py index 4b13a36bc1..56175879c6 100644 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -3,10 +3,10 @@ """Implements the serial map fusing transformation.""" import copy -from typing import Any, Dict, List, Set, Union, Optional +from typing import Any, Dict, List, Set, Tuple, Union, Optional import dace -from dace import dtypes, properties, subsets, symbolic, transformation +from dace import data, dtypes, properties, subsets, symbolic, transformation from dace.sdfg import SDFG, SDFGState, graph, nodes from dace.transformation.dataflow import map_fusion_helper as mfh @@ -14,38 +14,42 @@ @properties.make_properties class SerialMapFusion(mfh.MapFusionHelper): - """Specialized replacement for the map fusion transformation that is provided by DaCe. + """Fuse two serial maps together. - As its name is indicating this transformation is only able to handle Maps that - are in sequence. Compared to the native DaCe transformation, this one is able - to handle more complex cases of connection between the maps. In that sense, it - is much more similar to DaCe's `SubgraphFusion` transformation. + The transformation combines two maps into one that are connected through some + access nodes. Conceptually this transformation removes the exit of the first + or upper map and the entry of the lower or second map and then rewrites the + connections appropriately. Depending on the situation the transformation will + either fully remove or make the intermediate a new output of the second map. - Things that are improved, compared to the native DaCe implementation: - - Nested Maps. - - Temporary arrays and the correct propagation of their Memlets. - - Top Maps that have multiple outputs. - - Conceptually this transformation removes the exit of the first or upper map - and the entry of the lower or second map and then rewrites the connections - appropriately. - - This transformation assumes that an SDFG obeys the structure that is outlined - [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). For that - reason it is not true replacement of the native DaCe transformation. + By default, the transformation does not use the strict data flow mode, see + `MapFusionHelper` for more, however, it might be useful in come cases to enable + it. Args: only_inner_maps: Only match Maps that are internal, i.e. inside another Map. only_toplevel_maps: Only consider Maps that are at the top. + strict_dataflow: If `True`, the transformation ensures a more + stricter version of the data flow. Notes: - - This transformation modifies more nodes than it matches! - - Run simplify to get ri of excess keep alive nodes + - This transformation modifies more nodes than it matches. + - After the transformation has been applied simplify should be run to remove + some dead data flow, that was introduced to ensure validity. + - A `SerialMapFusion` obejct can be initialized and be reused. However, + after new access nodes are added to any state, it is no longer valid + to use the object. + + Todo: + - Consider the case that only shared nodes are created (thus no inspection of + the graph is needed) and make all shared. Then use the dead dataflow + elimination transformation to get rid of the ones we no longer need. + - Increase the applicability. """ - map_exit1 = transformation.transformation.PatternNode(nodes.MapExit) - access_node = transformation.transformation.PatternNode(nodes.AccessNode) - map_entry2 = transformation.transformation.PatternNode(nodes.MapEntry) + map_exit_1 = transformation.transformation.PatternNode(nodes.MapExit) + intermediate_access_node = transformation.transformation.PatternNode(nodes.AccessNode) + map_entry_2 = transformation.transformation.PatternNode(nodes.MapEntry) def __init__( self, @@ -64,7 +68,7 @@ def expressions(cls) -> Any: matched nodes, but more or less on anything that has an incoming connection from the first Map or an outgoing connection to the second Map entry. """ - return [dace.sdfg.utils.node_path_graph(cls.map_exit1, cls.access_node, cls.map_entry2)] + return [dace.sdfg.utils.node_path_graph(cls.map_exit_1, cls.intermediate_access_node, cls.map_entry_2)] def can_be_applied( @@ -77,14 +81,13 @@ def can_be_applied( """Tests if the matched Maps can be merged. The two Maps are mergeable iff: - - The `can_be_fused()` of the base succeed, which checks some basic constraints. - - The decomposition exists and at least one of the intermediate sets - is not empty. + - Checks general requirements, see `MapFusionHelper.can_be_fused()`. + - Tests if the decomposition exists. + - Tests if there are read write dependencies. """ - assert isinstance(self.map_exit1, nodes.MapExit) - assert isinstance(self.map_entry2, nodes.MapEntry) - map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit1) - map_entry_2: nodes.MapEntry = self.map_entry2 + map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit_1) + map_exit_1: nodes.MapExit = self.map_exit_1 + map_entry_2: nodes.MapEntry = self.map_entry_2 # This essentially test the structural properties of the two Maps. if not self.can_be_fused( @@ -107,8 +110,8 @@ def can_be_applied( output_partition = self.partition_first_outputs( state=graph, sdfg=sdfg, - map_exit_1=self.map_exit1, - map_entry_2=self.map_entry2, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, ) if output_partition is None: return False @@ -127,20 +130,25 @@ def has_read_write_dependency( ) -> bool: """Test if there is a read write dependency between the two maps to be fused. - The function first looks at the set of data that is read/written by the - two maps. If the function detects a possible conflict, the function will - evaluate the subsets of the read and write to determine if the conflict - can be resolved or not. + The function checks two different things. + - The function will make sure that there is no read write dependency between + the input and output of the fused maps. For that it will inspect the + respective subsets. + - The second part partially checks the intermediate nodes, it mostly ensures + that there are not views and that they are not used as inputs or outputs + at the same time. However, the function will not check for read write + conflicts in this set, this is done in the partition function. Returns: `True` if there is a conflict between the maps that can not be handled. - If there is no conflict or if the conflict can be handled then `False` + If there is no conflict or if the conflict can be handled `False` is returned. Args: map_entry_1: The entry node of the first map. map_entry_2: The entry node of the second map. state: The state on which we operate. + sdfg: The SDFG on which we operate. """ map_exit_1: nodes.MapExit = state.exit_node(map_entry_1) map_exit_2: nodes.MapExit = state.exit_node(map_entry_2) @@ -149,13 +157,20 @@ def has_read_write_dependency( # are not resolved yet. access_sets: List[Dict[str, nodes.AccessNode]] = [] for scope_node in [map_entry_1, map_exit_1, map_entry_2, map_exit_2]: + access_set: Set[nodes.AccessNode] = self.get_access_set(scope_node, state) access_sets.append({ node.data: node - for node in self.get_access_set(scope_node, state) + for node in access_set }) + # If two different access nodes of the same scoping node refers to the + # same data, then we consider this as a dependency we can not handle. + # It is only a problem for the intermediate nodes and might be possible + # to handle, but doing so is hard, so we just forbid it. + if len(access_set) != len(access_sets[-1]): + return True read_map_1, write_map_1, read_map_2, write_map_2 = access_sets - # It might be possible that there are views, so we have to resolve these sets. + # It might be possible that there are views, so we have to resolve them. # We also already get the name of the data container. # Note that `len(real_read_map_1) <= len(read_map_1)` holds because of Views. resolved_sets: List[Set[str]] = [] @@ -164,170 +179,63 @@ def has_read_write_dependency( self.track_view(node).data if self.is_view(node, sdfg) else node.data for node in unresolved_set.values() }) + # If the resolved and unresolved names do not have the same length. + # Then different views point to the same location, which we forbid + if len(unresolved_set) != len(resolved_sets[-1]): + return None real_read_map_1, real_write_map_1, real_read_map_2, real_write_map_2 = resolved_sets - # If the resolved and the unresolved set of input/output have different lengths, - # it means that there were two different views that ultimately referred to the - # same data. - for unresolved_access, resolved_access in zip(access_sets, resolved_sets): - if len(unresolved_access) != len(resolved_access): - return True - # We do not allow that the first and second map each write to the same data. - # The reason is because it is very hard to handle correctly. if not real_write_map_1.isdisjoint(real_write_map_2): return True - # The inputs and outputs are different, so there can not be any conflicts. - # Must be done after the test above! + # If there is no overlap in what is (totally) read and written, there will be no conflict. + # This must come before the check of disjoint write. if (real_read_map_1 | real_read_map_2).isdisjoint(real_write_map_1 | real_write_map_2): return False - # This is the set of nodes that is used to exchange data between the two maps. - # This works, because in the partition function we ensure that such nodes are - # directly connected. - read_map_2_nodes: Set[node.AccessNode] = set(read_map_2.values()) - exchange_nodes: Dict[str, nodes.AccessNode] = { - name: node - for name, node in write_map_1.items() - if node in read_map_2_nodes - } - - # For simplicity we assume that the nodes used to exchange information can - # not be a View. This is a simplification. - if any(self.is_view(exchange_node, sdfg) for exchange_node in exchange_nodes.values()): + # These are the names (unresolved) and the access nodes of the data that is used + # to transmit information between the maps. The partition function ensures that + # these nodes are directly connected to the two maps. + exchange_names: Set[str] = set(write_map_1.keys()).intersection(read_map_2.keys()) + exchange_nodes: Set[node.AccessNode] = set(write_map_1.values()).intersection(read_map_2.values()) + + # If the number are different then a data is accessed through multiple nodes. + if len(exchange_names) != len(exchange_nodes): + return True + assert all(exchange_node.data in exchange_names for exchange_node in exchange_nodes) + + # For simplicity we assume that the nodes used for exchange are not views. + if any(self.is_view(exchange_node, sdfg) for exchange_node in exchange_nodes): return True # This is the names of the node that are used as input of the first map and - # as output of the second map. We can not use the resolved here, because - # we forbid that these nodes are Views. + # as output of the second map. We have to ensure that there is no data + # dependency between these nodes. fused_inout_data_names: Set[str] = set(read_map_1.keys()).intersection(write_map_2.keys()) - # Because it is hard, we do not allow Views here, because we can not resolve - # access sets (at least I can not). + # If a data container is used as input and output then it can not be a view (simplicity) if any(self.is_view(read_map_1[name], sdfg) for name in fused_inout_data_names): return True - # This is a case that can not be handled, the above code should filter this - # out, so if you are here, then the above code might have problems, - # furthermore the code below assumes it. - assert fused_inout_data_names.isdisjoint(exchange_nodes.keys()), "Constraint violation." - assert (len(fused_inout_data_names) > 0) or (len(exchange_nodes) > 0) - - # We will now inspect the subsets. - repl_dict = self.find_parameter_remapping(map_entry_1.map, map_entry_2.map) - - # First we handle the rw dependency that is given by the whole fused map. - if not self._check_read_write_dependency_fused_map( - map_entry_1=map_entry_1, - map_exit_2=map_exit_2, - inout_data_names=fused_inout_data_names, - read_map_1=read_map_1, - write_map_2=write_map_2, - repl_dict=repl_dict, - state=state, - sdfg=sdfg): - return True # There are rw dependencies. - - # Now we check the exchange nodes, i.e. the common nodes between the maps, - # are point wise. - if not self._check_read_write_dependency_exchange_nodes( - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - exchange_nodes=exchange_nodes, - repl_dict=repl_dict, - state=state, - sdfg=sdfg, - ): - return True # There are rw dependencies. - - # No read write dependency was found. - return False - - - def _check_read_write_dependency_exchange_nodes( - self, - map_exit_1: nodes.MapExit, - map_entry_2: nodes.MapEntry, - exchange_nodes: Dict[str, nodes.AccessNode], - repl_dict: Union[Dict[str, str], None], - state: SDFGState, - sdfg: SDFG, - ) -> bool: - """Checks if there are any read after write dependencies in the exchange set. - - Args: - map_exit_1: Exit node of the first (top) map; defines writes. - map_entry_2: Entry node of the second (bottom) map; defines reads. - exchange_nodes: Exchange nodes, i.e. written and read by the maps. - repl_dict: Replacement dict, for renaming the subsets of the second map. - state: The state in which we operate. - sdfg: The containing SDFG. - """ - - for exchange_node in exchange_nodes.values(): - all_subsets: List[subsets.Subset] = [] - - # The reading subsets are defined by the entry of the second map, - # thus we also have to perform some replacing of the parameters. - all_subsets.extend( - self._find_subsets( - node=exchange_node, - scope_node=map_entry_2, - state=state, - sdfg=sdfg, - repl_dict=repl_dict, - ) - ) - - # The writing subset is given by the exit of the first map. No replacing - # is needed, but the node is the same. - all_subsets.extend( - self._find_subsets( - node=exchange_node, - scope_node=map_exit_1, - state=state, - sdfg=sdfg, - repl_dict=None, - ) - ) - - if not self._test_if_subsets_are_point_wise(all_subsets): - return False - - # All subsets are point wise - return True - + # A data container can be used as input and output. But we do not allow that + # it is also used as intermediate or exchange data. This is an important check. + if not fused_inout_data_names.isdisjoint(exchange_names): + return True - def _check_read_write_dependency_fused_map( - self, - map_entry_1: nodes.MapEntry, - map_exit_2: nodes.MapExit, - inout_data_names: Set[str], - read_map_1: Dict[str, nodes.AccessNode], - write_map_2: Dict[str, nodes.AccessNode], - repl_dict: Union[Dict[str, str], None], - state: SDFGState, - sdfg: SDFG, - ) -> bool: - """Checks the read write dependency that are given by the fused map. + # Get the replacement dict for changing the map variables from the subsets of + # the second map. + repl_dict = self.find_parameter_remapping(map_entry_1.map, map_exit_2.map) - Args: - map_entry_1: The map entry node of the first (top) map. - map_exit_2: The map exit node of the second (bottom) map. - inout_data_names: Names of all data containers that are conflicting. - read_map_1: All access nodes from which the first map reads (`node.data -> node`). - write_map_2: All access nodes to which the second map writes (`node.data -> node`). - repl_dict: Replacement dict for renaming the second maps iteration parameters. - state: The state in which we operate. - sdfg: The containing SDFG. - """ - for inout_data_name in inout_data_names: + # Now we inspect if there is a read write dependency, between data that is + # used as input and output of the fused map. There is no problem is they + # are pointwise, i.e. in each iteration the same locations are accessed. + # Essentially they all boil down to `a += 1`. + for inout_data_name in fused_inout_data_names: all_subsets: List[subsets.Subset] = [] - # The subsets that define reading are given by the first map's entry node all_subsets.extend( - self._find_subsets( + self.find_subsets( node=read_map_1[inout_data_name], scope_node=map_entry_1, state=state, @@ -335,11 +243,10 @@ def _check_read_write_dependency_fused_map( repl_dict=None, ) ) - # While the subsets defining writing are given by the second map's exit # node, there we also have to apply renaming. all_subsets.extend( - self._find_subsets( + self.find_subsets( node=write_map_2[inout_data_name], scope_node=map_exit_2, state=state, @@ -347,17 +254,15 @@ def _check_read_write_dependency_fused_map( repl_dict=repl_dict, ) ) - # Now we can test if these subsets are point wise - if not self._test_if_subsets_are_point_wise(all_subsets): - return False - - # All subsets are point wise - return True + if not self.test_if_subsets_are_point_wise(all_subsets): + return True + # No read write dependency was found. + return False - def _test_if_subsets_are_point_wise( + def test_if_subsets_are_point_wise( self, subsets_to_check: List[subsets.Subset] ) -> bool: @@ -400,60 +305,258 @@ def _test_if_subsets_are_point_wise( return True - def _find_subsets( - self, - node: nodes.AccessNode, - scope_node: Union[nodes.MapExit, nodes.MapEntry], - state: SDFGState, - sdfg: SDFG, - repl_dict: Optional[Dict[str, str]], - ) -> List[subsets.Subset]: - """Finds all subsets involving node `node`. - - The function will not start a search for all consumer/producers. - Instead it will locate the edges which is immediately inside the - map scope. + def partition_first_outputs( + self, + state: SDFGState, + sdfg: SDFG, + map_exit_1: nodes.MapExit, + map_entry_2: nodes.MapEntry, + ) -> Union[ + Tuple[ + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], + ], + None, + ]: + """Partition the output edges of `map_exit_1` for serial map fusion. + + The output edges of the first map are partitioned into three distinct sets, + defined as follows: + - Pure Output Set `\mathbb{P}`: + These edges exits the first map and does not enter the second map. These + outputs will be simply be moved to the output of the second map. + - Exclusive Intermediate Set `\mathbb{E}`: + Edges in this set leaves the first map exit, enters an access node, from + where a Memlet then leads immediately to the second map. The memory + referenced by this access node is not used anywhere else, thus it can + be removed. + - Shared Intermediate Set `\mathbb{S}`: + These edges are very similar to the one in `\mathbb{E}` except that they + are used somewhere else, thus they can not be removed and have to be + recreated as output of the second map. + + If strict data flow mode is enabled the function is rather strict if an + output can be added to either intermediate set and might fail to compute + the partition, even if it would exist. + + Returns: + If such a decomposition exists the function will return the three sets + mentioned above in the same order. + In case the decomposition does not exist, i.e. the maps can not be fused + the function returns `None`. Args: - node: The access node that should be examined. - scope_node: We are only interested in data that flows through this node. - state: The state in which we operate. - sdfg: The SDFG object. + state: The in which the two maps are located. + sdfg: The full SDFG in whcih we operate. + map_exit_1: The exit node of the first map. + map_entry_2: The entry node of the second map. """ - - # Is the node used for reading or for writing. - # This influences how we have to proceed. - if isinstance(scope_node, nodes.MapEntry): - used_for_reading = True - edges_to_inspect = state.in_edges(scope_node) - test_edge = lambda e: (e.src == node) - get_subset = lambda e: e.data.src_subset + # The three outputs set. + pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + exclusive_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + shared_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + + # These are the iteration parameters of the two maps. + # They are not yet modified, that they match each other. + map_params_1: Sequence[str] = map_exit_1.map.params + map_params_2: Sequence[str] = map_entry_2.map.params + + # Compute the renaming that for translating the parameter of the _second_ + # map to the ones used by the first map. + repl_dict: Dict[str, str] = self.find_parameter_remapping( + first_map=map_exit_1.map, + second_map=map_entry_2.map, + ) + assert repl_dict is not None + + # Set of intermediate nodes that we have already processed. + processed_inter_nodes: Set[nodes.Node] = set() + + # These are the data that is written to multiple times in _this_ state. + # If a data is written to multiple time in a state, it could be + # classified as shared. However, it might happen that the node has zero + # degree. This is not a problem as the maps also induced a before-after + # relationship. But some DaCe transformations do not catch this. + # Thus we will never modify such intermediate nodes and fail instead. + if self.strict_dataflow: + multi_write_data: Set[str] = self._compute_multi_write_data(state, sdfg) else: - used_for_reading = False - edges_to_inspect = state.out_edges(scope_node) - test_edge = lambda e: (e.dst == node) - get_subset = lambda e: e.data.dst_subset - - found_subsets: List[subsets.Subset] = [] - for edge in edges_to_inspect: - if not test_edge(edge): + multi_write_data = set() + + # Now scan all output edges of the first exit and classify them + for out_edge in state.out_edges(map_exit_1): + intermediate_node: nodes.Node = out_edge.dst + + # We already processed the node, this should indicate that we should + # run simplify again, or we should start implementing this case. + # TODO(phimuell): Handle this case, already partially handled here. + if intermediate_node in processed_inter_nodes: + return None + processed_inter_nodes.add(intermediate_node) + + # The intermediate can only have one incoming degree. It might be possible + # to handle multiple incoming edges, if they all come from the top map. + # However, the resulting SDFG might be invalid. + # NOTE: Allow this to happen (under certain cases) if the only producer + # is the top map. + if state.in_degree(intermediate_node) != 1: + return None + + # If the second map is not reachable from the intermediate node, then + # the output is pure and we can end here. + if not self.is_node_reachable_from( + graph=state, + begin=intermediate_node, + end=map_entry_2, + ): + pure_outputs.add(out_edge) continue - if used_for_reading: - inner_edges = state.out_edges_by_connector(scope_node, "OUT_" + edge.dst_conn[3:]) + + # The following tests are _after_ we have determined if we have a pure + # output node, because this allows us to handle more exotic pure node + # cases, as handling them is essentially rerouting an edge, whereas + # handling intermediate nodes is much more complicated. + + # For us an intermediate node must always be an access node, because + # everything else we do not know how to handle. It is important that + # we do not test for non transient data here, because they can be + # handled has shared intermediates. + if not isinstance(intermediate_node, nodes.AccessNode): + return None + intermediate_desc: data.Data = intermediate_node.desc(sdfg) + if self.is_view(intermediate_node, sdfg): + return None + + # Checks if the intermediate node refers to data that is accessed by + # _other_ access nodes in _this_ state. If this is the case then never + # touch this intermediate node. + # TODO(phimuell): Technically it would be enough to turn the node into + # a shared output node, because this will still fulfil the dependencies. + # However, some DaCe transformation can not handle this properly, so we + # are _forced_ to reject this node. + if intermediate_node.data in multi_write_data: + return None + + # Empty Memlets are only allowed if they are in `\mathbb{P}`, which + # is also the only place they really make sense (for a map exit). + # Thus if we now found an empty Memlet we reject it. + if out_edge.data.is_empty(): + return None + + # It can happen that multiple edges converges at the `IN_` connector + # of the first map exit, but there is only one edge leaving the exit. + # It is complicate to handle this, so for now we ignore it. + # TODO(phimuell): Handle this case properly. + # To handle this we need to associate a consumer edge (the outgoing edges + # of the second map) with exactly one producer. + producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list(state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:])) + if len(producer_edges) > 1: + return None + + # Now check the constraints we have on the producers. + # - The source of the producer can not be a view (we do not handle this) + # - The edge shall also not be a reduction edge. + # - Defined location to where they write. + # - No dynamic Melets. + # Furthermore, we will also extract the subsets, i.e. the location they + # modify inside the intermediate array. + # Since we do not allow for WCR, we do not check if the producer subsets intersects. + producer_subsets: List[subsets.Subset] = [] + for producer_edge in producer_edges: + if isinstance(producer_edge.src, nodes.AccessNode) and self.is_view(producer_edge.src, sdfg): + return None + if producer_edge.data.dynamic: + return None + if producer_edge.data.wcr is not None: + return None + if producer_edge.data.dst_subset is None: + return None + producer_subsets.append(producer_edge.data.dst_subset) + + # Check if the producer do not intersect + if len(producer_subsets) == 1: + pass + elif len(producer_subsets) == 2: + if producer_subsets[0].intersects(producer_subsets[1]): + return None else: - inner_edges = state.in_edges_by_connector(scope_node, "IN_" + edge.src_conn[4:]) - found_subsets.extend(get_subset(e) for e in inner_edges) - assert len(found_subsets) > 0, f"Could not find any subsets." - assert not any(subset is None for subset in found_subsets) + for i, psbs1 in enumerate(producer_subsets): + for j, psbs2 in enumerate(producer_subsets): + if i == j: + continue + if psbs1.intersects(psbs2): + return None + + # Now we determine the consumer of nodes. For this we are using the edges + # leaves the second map entry. It is not necessary to find the actual + # consumer nodes, as they might depend on symbols of nested Maps. + # For the covering test we only need their subsets, but we will perform + # some scan and filtering on them. + found_second_map = False + consumer_subsets: List[subsets.Subset] = [] + for intermediate_node_out_edge in state.out_edges(intermediate_node): + + # If the second map entry is not immediately reachable from the intermediate + # node, then ensure that there is not path that goes to it. + if intermediate_node_out_edge.dst is not map_entry_2: + if self.is_node_reachable_from(graph=state, begin=intermediate_node_out_edge.dst, end=map_entry_2): + return None + continue - # The deepcopy is needed if we would do renaming. - found_subsets = copy.deepcopy(found_subsets) - if repl_dict: - for subset in found_subsets: - # Replace happens in place - symbolic.safe_replace(repl_dict, subset.replace) + # Ensure that the second map is found exactly once. + # TODO(phimuell): Lift this restriction. + if found_second_map: + return None + found_second_map = True + + # The output of the top map can not define a dynamic map range in the + # second map. + if not intermediate_node_out_edge.dst_conn.startswith("IN_"): + return None + + # Now we look at all edges that leave the second map entry, i.e. the + # edges that feeds the consumer and define what is read inside the map. + # We do not check them, but collect them and inspect them. + # NOTE: The subset still uses the old iteration variables. + for inner_consumer_edge in state.out_edges_by_connector(map_entry_2, "OUT_" + intermediate_node_out_edge.dst_conn[3:]): + if inner_consumer_edge.data.src_subset is None: + return None + if inner_consumer_edge.data.dynamic: + # TODO(phimuell): Is this restriction necessary, I am not sure. + return None + consumer_subsets.append(inner_consumer_edge.data.src_subset) + assert found_second_map, f"Found '{intermediate_node}' which looked like a pure node, but is not one." + assert len(consumer_subsets) != 0 + + # The consumer still uses the original symbols of the second map, so we must rename them. + if repl_dict: + consumer_subsets = copy.deepcopy(consumer_subsets) + for consumer_subset in consumer_subsets: + symbolic.safe_replace(mapping=repl_dict, replace_callback=consumer_subset.replace) + + # Now we are checking if a single iteration of the first (top) map + # can satisfy all data requirements of the second (bottom) map. + # For this we look if the producer covers the consumer. A consumer must + # be covered by exactly one producer. + for consumer_subset in consumer_subsets: + nb_coverings = sum(producer_subset.covers(consumer_subset) for producer_subset in producer_subsets) + if nb_coverings != 1: + return None + + # After we have ensured coverage, we have to decide if the intermediate + # node can be removed (`\mathbb{E}`) or has to be restored (`\mathbb{S}`). + # Note that "removed" here means that it is reconstructed by a new + # output of the second map. + if self.is_shared_data(intermediate_node, sdfg): + # The intermediate data is used somewhere else, either in this or another state. + shared_outputs.add(out_edge) + else: + # The intermediate can be removed, as it is not used anywhere else. + exclusive_outputs.add(out_edge) - return found_subsets + assert len(processed_inter_nodes) == sum(len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs]) + return (pure_outputs, exclusive_outputs, shared_outputs) def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: @@ -473,13 +576,13 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non # once we start adding and removing nodes it seems that their ID changes. # Thus we have to save them here, this is a known behaviour in DaCe. assert isinstance(graph, dace.SDFGState) - assert isinstance(self.map_exit1, nodes.MapExit) - assert isinstance(self.map_entry2, nodes.MapEntry) + assert isinstance(self.map_exit_1, nodes.MapExit) + assert isinstance(self.map_entry_2, nodes.MapEntry) - map_exit_1: nodes.MapExit = self.map_exit1 - map_entry_2: nodes.MapEntry = self.map_entry2 - map_exit_2: nodes.MapExit = graph.exit_node(self.map_entry2) - map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit1) + map_exit_1: nodes.MapExit = self.map_exit_1 + map_entry_2: nodes.MapEntry = self.map_entry_2 + map_exit_2: nodes.MapExit = graph.exit_node(self.map_entry_2) + map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit_1) # Before we do anything we perform the renaming. self.rename_map_parameters( @@ -596,16 +699,15 @@ def handle_intermediate_set( pre_exit_edge = pre_exit_edges[0] new_inter_shape_raw = symbolic.overapproximate(pre_exit_edge.data.subset.size()) + # Over approximation will leave us with some unneeded size one dimensions. + # If they are removed some dace transformations (especially auto optimization) + # will have problems. if not self.strict_dataflow: - # Over approximation will leave us with some unneeded size one dimensions. - # If they are removed some dace transformations (especially auto optimization) - # will have problems. squeezed_dims: List[int] = [] # These are the dimensions we removed. new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. for dim, (proposed_dim_size, full_dim_size) in enumerate( zip(new_inter_shape_raw, inter_shape) ): - # Order of checks is important! if full_dim_size == 1: # Must be kept! new_inter_shape.append(proposed_dim_size) elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. @@ -616,8 +718,6 @@ def handle_intermediate_set( squeezed_dims = [] new_inter_shape = list(new_inter_shape_raw) - - # This is the name of the new "intermediate" node that we will create. # It will only have the shape `new_inter_shape` which is basically its # output within one Map iteration. @@ -677,13 +777,9 @@ def handle_intermediate_set( # We now handle the MemletTree defined by this edge. # The newly created edge, only handled the last collection step. - for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children(include_self=True): + for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children(include_self=False): producer_edge = producer_tree.edge - # Exclude the edge we just have created. - if producer_edge is new_pre_exit_edge: - continue - # Associate the (already existing) Memlet with the new data. # TODO(phimuell): Improve the code below to remove the check. assert producer_edge.data.data == inter_name @@ -727,7 +823,7 @@ def handle_intermediate_set( # As for the producer side, we now read from a smaller array, # So we must offset them, we use the original edge for this. assert inner_edge.data.src_subset is not None - inner_edge_correction_offset = inner_edge.data.src_subset + inner_edge_correction_offset = copy.deepcopy(inner_edge.data.src_subset) # Now we create a new connection that instead reads from the new # intermediate, instead of the old one. For this we use the @@ -754,21 +850,19 @@ def handle_intermediate_set( new_inner_memlet.src_subset.pop(squeezed_dims) # Now we have to make sure that all consumers are properly updated. - for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(include_self=True): - if consumer_tree.edge is new_inner_edge: - continue + for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(include_self=False): assert consumer_tree.edge.data.data == inter_name consumer_edge = consumer_tree.edge consumer_edge.data.data = new_inter_name if is_scalar: consumer_edge.data.src_subset = "0" - elif consumer_edge.data.subset is not None: + elif consumer_edge.data.src_subset is not None: consumer_edge.data.src_subset.offset(inner_edge_correction_offset, negative=True) consumer_edge.data.src_subset.pop(squeezed_dims) - # The edge that leaves the second map entry was already deleted. - # We will now delete the edges that brought the data. + # The edge that leaves the second map entry was already deleted. We now delete + # the edges that connected the intermediate node with the second map entry. for edge in list(state.in_edges_by_connector(map_entry_2, in_conn_name)): assert edge.src == inter_node state.remove_edge(edge) diff --git a/dace/transformation/dataflow/mapreduce.py b/dace/transformation/dataflow/mapreduce.py index 11445344ec..de76eee0ba 100644 --- a/dace/transformation/dataflow/mapreduce.py +++ b/dace/transformation/dataflow/mapreduce.py @@ -216,9 +216,8 @@ def apply(self, graph: SDFGState, sdfg: SDFG): map_entry, _ = map_collapse.apply(graph, sdfg) map_fusion = MapFusion() - # What is with the array? map_fusion.setup_match(sdfg, self.cfg_id, self.state_id, { - MapFusion.map_exit1: graph.node_id(self.tmap_exit), - MapFusion.map_entry2: graph.node_id(map_entry), + MapFusion.map_exit_1: graph.node_id(self.tmap_exit), + MapFusion.map_entry_2: graph.node_id(map_entry), }, 0) map_fusion.apply(graph, sdfg) diff --git a/tests/transformations/apply_to_test.py b/tests/transformations/apply_to_test.py index 2b1828f5ba..47854a1948 100644 --- a/tests/transformations/apply_to_test.py +++ b/tests/transformations/apply_to_test.py @@ -31,7 +31,7 @@ def test_applyto_pattern(): transient = next(aname for aname, desc in sdfg.arrays.items() if desc.transient) access_node = next(n for n in state.nodes() if isinstance(n, dace.nodes.AccessNode) and n.data == transient) - MapFusion.apply_to(sdfg, map_exit1=mult_exit, access_node=access_node, map_entry2=add_entry) + MapFusion.apply_to(sdfg, map_exit_1=mult_exit, intermediate_access_node=access_node, map_entry_2=add_entry) def test_applyto_enumerate(): @@ -42,9 +42,9 @@ def test_applyto_enumerate(): pattern = sdutil.node_path_graph(dace.nodes.MapExit, dace.nodes.AccessNode, dace.nodes.MapEntry) for subgraph in enumerate_matches(sdfg, pattern): MapFusion.apply_to(sdfg, - map_exit1=subgraph.source_nodes()[0], - access_node=next(n for n in subgraph.nodes() if isinstance(n, dace.nodes.AccessNode)), - map_entry2=subgraph.sink_nodes()[0]) + map_exit_1=subgraph.source_nodes()[0], + intermediate_access_node=next(n for n in subgraph.nodes() if isinstance(n, dace.nodes.AccessNode)), + map_entry_2=subgraph.sink_nodes()[0]) def test_applyto_subgraph(): diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 0f444f5382..d828da1242 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -568,8 +568,8 @@ def _make_sdfg(): ParallelMapFusion.apply_to( sdfg, - map_entry1=map1_entry, - map_entry2=map2_entry, + map_entry_1=map1_entry, + map_entry_2=map2_entry, verify=True, ) assert count_node(sdfg, dace.sdfg.nodes.MapEntry) == 1 From 7ccdd9c587a2c5f6d4b7ba41d06b0c0dcbcd761d Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 6 Sep 2024 14:25:48 +0200 Subject: [PATCH 067/115] Made a rename. --- dace/transformation/dataflow/__init__.py | 2 +- dace/transformation/dataflow/map_fusion.py | 6 +++--- dace/transformation/dataflow/map_fusion_parallel.py | 6 +++--- dace/transformation/dataflow/map_fusion_serial.py | 4 ++-- tests/transformations/mapfusion_test.py | 8 ++++---- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/dace/transformation/dataflow/__init__.py b/dace/transformation/dataflow/__init__.py index dbd3838d9f..e7e20c69d1 100644 --- a/dace/transformation/dataflow/__init__.py +++ b/dace/transformation/dataflow/__init__.py @@ -8,7 +8,7 @@ from .map_for_loop import MapToForLoop, MapToForLoopRegion from .map_interchange import MapInterchange from .map_dim_shuffle import MapDimShuffle -from .map_fusion import MapFusion, ParallelMapFusion, SerialMapFusion +from .map_fusion import MapFusion, MapFusionParallel, MapFusionSerial from .map_fission import MapFission from .map_unroll import MapUnroll from .trivial_map_elimination import TrivialMapElimination diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 3735d3e7dc..e737fab91c 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -1,8 +1,8 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. """Make all map fusion transformations available.""" -from .map_fusion_serial import SerialMapFusion -from .map_fusion_parallel import ParallelMapFusion +from .map_fusion_serial import MapFusionSerial +from .map_fusion_parallel import MapFusionParallel # Compatibility with previous versions of DaCe and clients. -MapFusion = SerialMapFusion +MapFusion = MapFusionSerial diff --git a/dace/transformation/dataflow/map_fusion_parallel.py b/dace/transformation/dataflow/map_fusion_parallel.py index f78bcd0b76..0b8102860a 100644 --- a/dace/transformation/dataflow/map_fusion_parallel.py +++ b/dace/transformation/dataflow/map_fusion_parallel.py @@ -11,10 +11,10 @@ from dace.transformation.dataflow import map_fusion_helper @properties.make_properties -class ParallelMapFusion(map_fusion_helper.MapFusionHelper): - """The `ParallelMapFusion` transformation allows to merge two parallel maps. +class MapFusionParallel(map_fusion_helper.MapFusionHelper): + """The `MapFusionParallel` transformation allows to merge two parallel maps. - While the `SerialMapFusion` transformation fuses maps that are sequentially + While the `MapFusionSerial` transformation fuses maps that are sequentially connected through an intermediate node, this transformation is able to fuse any two maps that are not sequential and in the same scope. diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py index 56175879c6..f56895b90a 100644 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -13,7 +13,7 @@ @properties.make_properties -class SerialMapFusion(mfh.MapFusionHelper): +class MapFusionSerial(mfh.MapFusionHelper): """Fuse two serial maps together. The transformation combines two maps into one that are connected through some @@ -36,7 +36,7 @@ class SerialMapFusion(mfh.MapFusionHelper): - This transformation modifies more nodes than it matches. - After the transformation has been applied simplify should be run to remove some dead data flow, that was introduced to ensure validity. - - A `SerialMapFusion` obejct can be initialized and be reused. However, + - A `MapFusionSerial` obejct can be initialized and be reused. However, after new access nodes are added to any state, it is no longer valid to use the object. diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index d828da1242..476a77a7ba 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -8,7 +8,7 @@ from dace import SDFG, SDFGState from dace.sdfg import nodes -from dace.transformation.dataflow import SerialMapFusion, ParallelMapFusion +from dace.transformation.dataflow import MapFusionSerial, MapFusionParallel def count_node(sdfg: SDFG, node_type): @@ -33,7 +33,7 @@ def apply_fusion( """ num_maps_before = count_node(sdfg, nodes.MapEntry) org_sdfg = copy.deepcopy(sdfg) - sdfg.apply_transformations_repeated(SerialMapFusion, validate=True, validate_all=True) + sdfg.apply_transformations_repeated(MapFusionSerial, validate=True, validate_all=True) num_maps_after = count_node(sdfg, nodes.MapEntry) has_processed = False @@ -496,7 +496,7 @@ def test_interstate_fusion(): ref_C = A + 30 ref_D = A + 26 - assert sdfg.apply_transformations_repeated(SerialMapFusion, validate=True, validate_all=True) == 1 + assert sdfg.apply_transformations_repeated(MapFusionSerial, validate=True, validate_all=True) == 1 assert sdfg.number_of_nodes() == 2 assert len([node for node in state1.data_nodes() if node.data == "B"]) == 1 @@ -566,7 +566,7 @@ def _make_sdfg(): if mode: map1_entry, map2_entry = map2_entry, map1_entry - ParallelMapFusion.apply_to( + MapFusionParallel.apply_to( sdfg, map_entry_1=map1_entry, map_entry_2=map2_entry, From d4041b71617acbc713b0a211591e7aa4ed758aa2 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 6 Sep 2024 14:29:29 +0200 Subject: [PATCH 068/115] Applied the formating. --- .../dataflow/map_fusion_helper.py | 95 +++++++------------ .../dataflow/map_fusion_parallel.py | 25 ++--- .../dataflow/map_fusion_serial.py | 78 ++++++--------- 3 files changed, 73 insertions(+), 125 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index 4cb1f4ca4a..cfdbeb3560 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -1,5 +1,4 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - """Implements Helper functionaliyies for map fusion""" import functools @@ -50,17 +49,16 @@ class MapFusionHelper(transformation.SingleStateTransformation): ) shared_data = properties.DictProperty( key_type=SDFG, - value_type=set, #[str] + value_type=set, #[str] default=None, allow_none=True, - optional=True, # Do not serialize. + optional=True, # Do not serialize. optional_condition=lambda _: False, desc="Maps SDFGs to the set of data that can not be removed," " because they transmit data _between states_, such data will be made 'shared'." " This variable acts as a cache, and is managed by 'is_shared_data()'.", ) - def __init__( self, only_inner_maps: Optional[bool] = None, @@ -77,12 +75,10 @@ def __init__( self.strict_dataflow = bool(strict_dataflow) self.shared_data = {} - @classmethod def expressions(cls) -> bool: raise RuntimeError("The `MapFusionHelper` is not a transformation on its own.") - def can_be_fused( self, map_entry_1: nodes.MapEntry, @@ -130,7 +126,6 @@ def can_be_fused( return True - def relocate_nodes( self, from_node: Union[nodes.MapExit, nodes.MapEntry], @@ -181,15 +176,12 @@ def relocate_nodes( # TODO(phimuell): Check if the symbol is really unused in the target scope. if dmr_symbol in to_node.in_connectors: - raise NotImplementedError( - f"Tried to move the dynamic map range '{dmr_symbol}' from {from_node}'" - f" to '{to_node}', but the symbol is already known there, but the" - " renaming is not implemented." - ) + raise NotImplementedError(f"Tried to move the dynamic map range '{dmr_symbol}' from {from_node}'" + f" to '{to_node}', but the symbol is already known there, but the" + " renaming is not implemented.") if not to_node.add_in_connector(dmr_symbol, force=False): raise RuntimeError( # Might fail because of out connectors. - f"Failed to add the dynamic map range symbol '{dmr_symbol}' to '{to_node}'." - ) + f"Failed to add the dynamic map range symbol '{dmr_symbol}' to '{to_node}'.") helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) from_node.remove_in_connector(dmr_symbol) @@ -203,9 +195,7 @@ def relocate_nodes( helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) to_node.add_out_connector("OUT_" + new_conn) for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): - helpers.redirect_edge( - state, e, new_src=to_node, new_src_conn="OUT_" + new_conn - ) + helpers.redirect_edge(state, e, new_src=to_node, new_src_conn="OUT_" + new_conn) from_node.remove_in_connector("IN_" + old_conn) from_node.remove_out_connector("OUT_" + old_conn) @@ -225,12 +215,7 @@ def relocate_nodes( assert len(from_node.in_connectors) == 0 assert len(from_node.out_connectors) == 0 - - def find_parameter_remapping( - self, - first_map: nodes.Map, - second_map: nodes.Map - ) -> Union[Dict[str, str], None]: + def find_parameter_remapping(self, first_map: nodes.Map, second_map: nodes.Map) -> Union[Dict[str, str], None]: """Computes the parameter remapping for the parameters of the _second_ map. The returned `dict` maps the parameters of the second map (keys) to parameter @@ -258,12 +243,12 @@ def find_parameter_remapping( # The ranges, however, we apply some post processing to them. simp = lambda e: symbolic.simplify_ext(symbolic.simplify(e)) first_rngs: Dict[str, Tuple[Any, Any, Any]] = { - param: tuple(simp(r) for r in rng) - for param, rng in zip(first_params, first_map.range) + param: tuple(simp(r) for r in rng) + for param, rng in zip(first_params, first_map.range) } second_rngs: Dict[Tuple[Any, Any, Any]] = { - param: tuple(simp(r) for r in rng) - for param, rng in zip(second_params, second_map.range) + param: tuple(simp(r) for r in rng) + for param, rng in zip(second_params, second_map.range) } # Parameters of the second map that have not yet been matched to a parameter @@ -312,7 +297,6 @@ def find_parameter_remapping( assert len(final_mapping) == len(unmapped_second_params) return final_mapping - def rename_map_parameters( self, first_map: nodes.Map, @@ -343,8 +327,8 @@ def rename_map_parameters( second_map_scope = state.scope_subgraph(entry_node=second_map_entry) # Why is this thing is symbolic and not in replace? symbolic.safe_replace( - mapping=repl_dict, - replace_callback=second_map_scope.replace_dict, + mapping=repl_dict, + replace_callback=second_map_scope.replace_dict, ) # For some odd reason the replace function does not modify the range and @@ -352,7 +336,6 @@ def rename_map_parameters( second_map.params = copy.deepcopy(first_map.params) second_map.range = copy.deepcopy(first_map.range) - def is_shared_data( self, data: nodes.AccessNode, @@ -380,7 +363,6 @@ def is_shared_data( self._compute_shared_data(sdfg) return data.data in self.shared_data[sdfg] - def _compute_shared_data( self, sdfg: dace.SDFG, @@ -427,7 +409,7 @@ def _compute_shared_data( # will get rid of them. shared_data.add(access_node.data) - elif state.out_degree(access_node) != 1: # state.out_degree() == 0 or state.out_degree() > 1 + elif state.out_degree(access_node) != 1: # state.out_degree() == 0 or state.out_degree() > 1 # The access node is either a source node (it is shared in another # state) or the node has a degree larger than one, so it is used # in this state somewhere else. @@ -456,7 +438,6 @@ def _compute_shared_data( # Update the internal cache self.shared_data[sdfg] = shared_data - def _compute_multi_write_data( self, state: SDFGState, @@ -493,7 +474,6 @@ def _compute_multi_write_data( data_written_to.add(access_node.data) return multi_write_data - def is_node_reachable_from( self, graph: Union[dace.SDFG, dace.SDFGState], @@ -511,6 +491,7 @@ def is_node_reachable_from( begin: The start of the DFS. end: The node that should be located. """ + def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: return (edge.dst for edge in graph.out_edges(node)) @@ -528,11 +509,10 @@ def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: # We never found `end` return False - def get_access_set( - self, - scope_node: Union[nodes.MapEntry, nodes.MapExit], - state: SDFGState, + self, + scope_node: Union[nodes.MapEntry, nodes.MapExit], + state: SDFGState, ) -> Set[nodes.AccessNode]: """Computes the access set of a "scope node". @@ -554,22 +534,19 @@ def get_access_set( get_edges = lambda node: state.out_edges(node) other_node = lambda e: e.dst access_set: Set[nodes.AccessNode] = { - node - for node in map(other_node, get_edges(scope_node)) - if isinstance(node, nodes.AccessNode) + node + for node in map(other_node, get_edges(scope_node)) if isinstance(node, nodes.AccessNode) } - return access_set - def find_subsets( - self, - node: nodes.AccessNode, - scope_node: Union[nodes.MapExit, nodes.MapEntry], - state: SDFGState, - sdfg: SDFG, - repl_dict: Optional[Dict[str, str]], + self, + node: nodes.AccessNode, + scope_node: Union[nodes.MapExit, nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, + repl_dict: Optional[Dict[str, str]], ) -> List[subsets.Subset]: """Finds all subsets that access `node` within `scope_node`. @@ -607,24 +584,22 @@ def find_subsets( # Replace happens in place symbolic.safe_replace(repl_dict, subset.replace) - return found_subsets - + return found_subsets def is_view( - self, - node: nodes.AccessNode, - sdfg: SDFG, + self, + node: nodes.AccessNode, + sdfg: SDFG, ) -> bool: """Tests if `node` points to a view or not.""" node_desc: data.Data = node.desc(sdfg) return isinstance(node_desc, data.View) - def track_view( - self, - view: nodes.AccessNode, - state: SDFGState, - sdfg: SDFG, + self, + view: nodes.AccessNode, + state: SDFGState, + sdfg: SDFG, ) -> nodes.AccessNode: """Find the original data of a View. diff --git a/dace/transformation/dataflow/map_fusion_parallel.py b/dace/transformation/dataflow/map_fusion_parallel.py index 0b8102860a..63c75f1d56 100644 --- a/dace/transformation/dataflow/map_fusion_parallel.py +++ b/dace/transformation/dataflow/map_fusion_parallel.py @@ -1,5 +1,4 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - """Implements the parallel map fusing transformation.""" from typing import Any, Optional, Set, Union @@ -10,6 +9,7 @@ from dace.transformation.dataflow import map_fusion_helper + @properties.make_properties class MapFusionParallel(map_fusion_helper.MapFusionHelper): """The `MapFusionParallel` transformation allows to merge two parallel maps. @@ -39,7 +39,6 @@ class MapFusionParallel(map_fusion_helper.MapFusionHelper): desc="Only perform fusing if the Maps share a node as parent.", ) - def __init__( self, only_if_common_ancestor: Optional[bool] = None, @@ -49,7 +48,6 @@ def __init__( self.only_if_common_ancestor = only_if_common_ancestor super().__init__(**kwargs) - @classmethod def expressions(cls) -> Any: # This just matches _any_ two Maps inside a state. @@ -57,7 +55,6 @@ def expressions(cls) -> Any: state.add_nodes_from([cls.map_entry_1, cls.map_entry_2]) return [state] - def can_be_applied( self, graph: Union[SDFGState, SDFG], @@ -75,11 +72,11 @@ def can_be_applied( # Check the structural properties of the maps, this will also ensure that # the two maps are in the same scope and the parameters can be renamed if not self.can_be_fused( - map_entry_1=map_entry_1, - map_entry_2=map_entry_2, - graph=graph, - sdfg=sdfg, - permissive=permissive, + map_entry_1=map_entry_1, + map_entry_2=map_entry_2, + graph=graph, + sdfg=sdfg, + permissive=permissive, ): return False @@ -98,7 +95,6 @@ def can_be_applied( return True - def is_parallel( self, graph: SDFGState, @@ -129,7 +125,6 @@ def is_parallel( return False return True - def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: """Performs the Map fusing. @@ -144,10 +139,10 @@ def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: # Before we do anything we perform the renaming. self.rename_map_parameters( - first_map=map_entry_1.map, - second_map=map_entry_2.map, - second_map_entry=map_entry_2, - state=graph, + first_map=map_entry_1.map, + second_map=map_entry_2.map, + second_map_entry=map_entry_2, + state=graph, ) for to_node, from_node in zip((map_entry_1, map_exit_1), (map_entry_2, map_exit_2)): diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py index f56895b90a..6c058e8043 100644 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -1,5 +1,4 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - """Implements the serial map fusing transformation.""" import copy @@ -57,7 +56,6 @@ def __init__( ) -> None: super().__init__(**kwargs) - @classmethod def expressions(cls) -> Any: """Get the match expression. @@ -70,7 +68,6 @@ def expressions(cls) -> Any: """ return [dace.sdfg.utils.node_path_graph(cls.map_exit_1, cls.intermediate_access_node, cls.map_entry_2)] - def can_be_applied( self, graph: Union[SDFGState, SDFG], @@ -90,9 +87,7 @@ def can_be_applied( map_entry_2: nodes.MapEntry = self.map_entry_2 # This essentially test the structural properties of the two Maps. - if not self.can_be_fused( - map_entry_1=map_entry_1, map_entry_2=map_entry_2, graph=graph, sdfg=sdfg - ): + if not self.can_be_fused(map_entry_1=map_entry_1, map_entry_2=map_entry_2, graph=graph, sdfg=sdfg): return False # Test for read-write conflicts @@ -120,7 +115,6 @@ def can_be_applied( return False return True - def has_read_write_dependency( self, map_entry_1: nodes.MapEntry, @@ -158,10 +152,7 @@ def has_read_write_dependency( access_sets: List[Dict[str, nodes.AccessNode]] = [] for scope_node in [map_entry_1, map_exit_1, map_entry_2, map_exit_2]: access_set: Set[nodes.AccessNode] = self.get_access_set(scope_node, state) - access_sets.append({ - node.data: node - for node in access_set - }) + access_sets.append({node.data: node for node in access_set}) # If two different access nodes of the same scoping node refers to the # same data, then we consider this as a dependency we can not handle. # It is only a problem for the intermediate nodes and might be possible @@ -241,8 +232,7 @@ def has_read_write_dependency( state=state, sdfg=sdfg, repl_dict=None, - ) - ) + )) # While the subsets defining writing are given by the second map's exit # node, there we also have to apply renaming. all_subsets.extend( @@ -252,8 +242,7 @@ def has_read_write_dependency( state=state, sdfg=sdfg, repl_dict=repl_dict, - ) - ) + )) # Now we can test if these subsets are point wise if not self.test_if_subsets_are_point_wise(all_subsets): return True @@ -261,11 +250,7 @@ def has_read_write_dependency( # No read write dependency was found. return False - - def test_if_subsets_are_point_wise( - self, - subsets_to_check: List[subsets.Subset] - ) -> bool: + def test_if_subsets_are_point_wise(self, subsets_to_check: List[subsets.Subset]) -> bool: """Point wise means that they are all the same. If a series of subsets are point wise it means that all Memlets, access @@ -286,7 +271,7 @@ def test_if_subsets_are_point_wise( if isinstance(subset, subsets.Indices): subset = subsets.Range.from_indices(subset) # Do we also need the reverse? See below why. - if any(r != (0, 0, 1) for r in test in subset.offset_new(master_subset,negative=True)): + if any(r != (0, 0, 1) for r in test in subset.offset_new(master_subset, negative=True)): return False else: # The original code used `Range.offset` here, but that one had trouble @@ -304,7 +289,6 @@ def test_if_subsets_are_point_wise( # point wise return True - def partition_first_outputs( self, state: SDFGState, @@ -312,12 +296,12 @@ def partition_first_outputs( map_exit_1: nodes.MapExit, map_entry_2: nodes.MapEntry, ) -> Union[ - Tuple[ - Set[graph.MultiConnectorEdge[dace.Memlet]], - Set[graph.MultiConnectorEdge[dace.Memlet]], - Set[graph.MultiConnectorEdge[dace.Memlet]], - ], - None, + Tuple[ + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], + ], + None, ]: """Partition the output edges of `map_exit_1` for serial map fusion. @@ -365,8 +349,8 @@ def partition_first_outputs( # Compute the renaming that for translating the parameter of the _second_ # map to the ones used by the first map. repl_dict: Dict[str, str] = self.find_parameter_remapping( - first_map=map_exit_1.map, - second_map=map_entry_2.map, + first_map=map_exit_1.map, + second_map=map_entry_2.map, ) assert repl_dict is not None @@ -406,9 +390,9 @@ def partition_first_outputs( # If the second map is not reachable from the intermediate node, then # the output is pure and we can end here. if not self.is_node_reachable_from( - graph=state, - begin=intermediate_node, - end=map_entry_2, + graph=state, + begin=intermediate_node, + end=map_entry_2, ): pure_outputs.add(out_edge) continue @@ -450,7 +434,8 @@ def partition_first_outputs( # TODO(phimuell): Handle this case properly. # To handle this we need to associate a consumer edge (the outgoing edges # of the second map) with exactly one producer. - producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list(state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:])) + producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list( + state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:])) if len(producer_edges) > 1: return None @@ -519,7 +504,8 @@ def partition_first_outputs( # edges that feeds the consumer and define what is read inside the map. # We do not check them, but collect them and inspect them. # NOTE: The subset still uses the old iteration variables. - for inner_consumer_edge in state.out_edges_by_connector(map_entry_2, "OUT_" + intermediate_node_out_edge.dst_conn[3:]): + for inner_consumer_edge in state.out_edges_by_connector( + map_entry_2, "OUT_" + intermediate_node_out_edge.dst_conn[3:]): if inner_consumer_edge.data.src_subset is None: return None if inner_consumer_edge.data.dynamic: @@ -558,7 +544,6 @@ def partition_first_outputs( assert len(processed_inter_nodes) == sum(len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs]) return (pure_outputs, exclusive_outputs, shared_outputs) - def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: """Performs the serial Map fusing. @@ -586,10 +571,10 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non # Before we do anything we perform the renaming. self.rename_map_parameters( - first_map=map_exit_1.map, - second_map=map_entry_2.map, - second_map_entry=map_entry_2, - state=graph, + first_map=map_exit_1.map, + second_map=map_entry_2.map, + second_map_entry=map_entry_2, + state=graph, ) output_partition = self.partition_first_outputs( @@ -647,7 +632,6 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non # Now turn the second output node into the output node of the first Map. map_exit_2.map = map_entry_1.map - def handle_intermediate_set( self, intermediate_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]], @@ -691,9 +675,7 @@ def handle_intermediate_set( # Now we will determine the shape of the new intermediate. This size of # this temporary is given by the Memlet that goes into the first map exit. - pre_exit_edges = list( - state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) - ) + pre_exit_edges = list(state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:])) if len(pre_exit_edges) != 1: raise NotImplementedError() pre_exit_edge = pre_exit_edges[0] @@ -705,9 +687,7 @@ def handle_intermediate_set( if not self.strict_dataflow: squeezed_dims: List[int] = [] # These are the dimensions we removed. new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. - for dim, (proposed_dim_size, full_dim_size) in enumerate( - zip(new_inter_shape_raw, inter_shape) - ): + for dim, (proposed_dim_size, full_dim_size) in enumerate(zip(new_inter_shape_raw, inter_shape)): if full_dim_size == 1: # Must be kept! new_inter_shape.append(proposed_dim_size) elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. @@ -737,9 +717,7 @@ def handle_intermediate_set( ) else: - assert (pre_exit_edge.data.subset.num_elements() > 1) or all( - x == 1 for x in new_inter_shape - ) + assert (pre_exit_edge.data.subset.num_elements() > 1) or all(x == 1 for x in new_inter_shape) is_scalar = False new_inter_name, new_inter_desc = sdfg.add_transient( new_inter_name, From a023f7cafb4372a88f73d5cf01441dec7d725a79 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 9 Sep 2024 08:00:47 +0200 Subject: [PATCH 069/115] Updated the comment about the wrong filter check in `SDFGState._read_and_write_sets()`. --- dace/sdfg/state.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 8c8c9ab15c..21a0d45de3 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -794,13 +794,9 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, for in_edge in in_edges: if out_edge.data.data != in_edge.data.data: # NOTE: This check does not make any sense, and is in my view wrong. - # If we consider a memlet between two access nodes, to which access - # node the `data` attribute of the memlet refers to is arbitrary and - # does not matter. However, the test will filter _some_ out but not - # all. See also the tests inside `tests/sdfg/state_test.py` for the - # wrong behaviour this check induces. - # This check is is retained for compatibility with `RefineNestedAccess`, - # see `tests/numpy/ufunc_support_test.py::test_ufunc_add_accumulate_simple`. + # As it will filter out some accesses but not all, which one solely + # depends on how the memelts were created. + # See also [issue #1634](https://github.com/spcl/dace/issues/1643). continue if in_subsets[in_edge].covers(out_subsets[out_edge]): out_edges.remove(out_edge) From 5c49eee663a6ba9d3e64775d93c56cac86d0f683 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 9 Sep 2024 08:05:18 +0200 Subject: [PATCH 070/115] Removed the wrong check in `SDFGState._read_and_write_sets()`, see also Issue #1634 for more details. As it is written in the issue, I can not simply remove the check but I also have to adapte the tests. The main important one is `tests/transformations/move_loop_into_map_test.py::MoveLoopIntoMapTest::test_more_than_a_map` where the behaviour has changed. However, after carefull examination I am sure that the test is still correct, or better now works correct as there is no dependency. --- dace/sdfg/state.py | 6 ------ tests/sdfg/state_test.py | 3 --- tests/transformations/move_loop_into_map_test.py | 8 ++++++-- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 21a0d45de3..ab49a9285f 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -792,12 +792,6 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, # Filter out memlets which go out but the same data is written to the AccessNode by another memlet for out_edge in list(out_edges): for in_edge in in_edges: - if out_edge.data.data != in_edge.data.data: - # NOTE: This check does not make any sense, and is in my view wrong. - # As it will filter out some accesses but not all, which one solely - # depends on how the memelts were created. - # See also [issue #1634](https://github.com/spcl/dace/issues/1643). - continue if in_subsets[in_edge].covers(out_subsets[out_edge]): out_edges.remove(out_edge) break diff --git a/tests/sdfg/state_test.py b/tests/sdfg/state_test.py index 198ccc4ecf..3bf28bfb51 100644 --- a/tests/sdfg/state_test.py +++ b/tests/sdfg/state_test.py @@ -89,9 +89,6 @@ def test_read_and_write_set_filter(): expected_reads = { "A": [sbs.Range.from_string("0, 0")], - # See comment in `state._read_and_write_sets()` why "B" is here - # it should actually not, but it is a bug. - "B": [sbs.Range.from_string("0")], } expected_writes = { # However, this should always be here. diff --git a/tests/transformations/move_loop_into_map_test.py b/tests/transformations/move_loop_into_map_test.py index dca775bb7a..c06661acc0 100644 --- a/tests/transformations/move_loop_into_map_test.py +++ b/tests/transformations/move_loop_into_map_test.py @@ -147,7 +147,11 @@ def test_apply_multiple_times_1(self): self.assertTrue(np.allclose(val, ref)) def test_more_than_a_map(self): - """ `out` is read and written indirectly by the MapExit, potentially leading to a RW dependency. """ + """`out` is read and written indirectly by the MapExit, potentially leading to a RW dependency. + + Note: + However, there is no write conflict and the transformation can be applied. + """ sdfg = dace.SDFG('more_than_a_map') _, aarr = sdfg.add_array('A', (3, 3), dace.float64) _, barr = sdfg.add_array('B', (3, 3), dace.float64) @@ -171,7 +175,7 @@ def test_more_than_a_map(self): body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1') count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertFalse(count > 0) + self.assertTrue(count > 0) def test_more_than_a_map_1(self): """ From 63e78c9004d0ce97c3aaf9644aaf6c1033b654f6 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 9 Sep 2024 09:02:12 +0200 Subject: [PATCH 071/115] Had to reenable the check in `SDFGState._read_and_write_sets()` is disabled in 5c49eee66. The reason is because `tests/numpy/ufunc_support_test.py::test_ufunc_add_accumulate_simple` fails (in auto optimizer mode). I remember that now. Also the issue is 1643. --- dace/sdfg/state.py | 6 ++++++ tests/sdfg/state_test.py | 3 +++ tests/transformations/move_loop_into_map_test.py | 8 ++------ 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index ab49a9285f..9c8994e56b 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -792,6 +792,12 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, # Filter out memlets which go out but the same data is written to the AccessNode by another memlet for out_edge in list(out_edges): for in_edge in in_edges: + if out_edge.data.data != in_edge.data.data: + # NOTE: This check does not make any sense, and is in my view wrong. + # As it will filter out some accesses but not all, which one solely + # depends on how the memelts were created. + # See also [issue #1643](https://github.com/spcl/dace/issues/1643). + continue if in_subsets[in_edge].covers(out_subsets[out_edge]): out_edges.remove(out_edge) break diff --git a/tests/sdfg/state_test.py b/tests/sdfg/state_test.py index 3bf28bfb51..198ccc4ecf 100644 --- a/tests/sdfg/state_test.py +++ b/tests/sdfg/state_test.py @@ -89,6 +89,9 @@ def test_read_and_write_set_filter(): expected_reads = { "A": [sbs.Range.from_string("0, 0")], + # See comment in `state._read_and_write_sets()` why "B" is here + # it should actually not, but it is a bug. + "B": [sbs.Range.from_string("0")], } expected_writes = { # However, this should always be here. diff --git a/tests/transformations/move_loop_into_map_test.py b/tests/transformations/move_loop_into_map_test.py index c06661acc0..dca775bb7a 100644 --- a/tests/transformations/move_loop_into_map_test.py +++ b/tests/transformations/move_loop_into_map_test.py @@ -147,11 +147,7 @@ def test_apply_multiple_times_1(self): self.assertTrue(np.allclose(val, ref)) def test_more_than_a_map(self): - """`out` is read and written indirectly by the MapExit, potentially leading to a RW dependency. - - Note: - However, there is no write conflict and the transformation can be applied. - """ + """ `out` is read and written indirectly by the MapExit, potentially leading to a RW dependency. """ sdfg = dace.SDFG('more_than_a_map') _, aarr = sdfg.add_array('A', (3, 3), dace.float64) _, barr = sdfg.add_array('B', (3, 3), dace.float64) @@ -175,7 +171,7 @@ def test_more_than_a_map(self): body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1') count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertTrue(count > 0) + self.assertFalse(count > 0) def test_more_than_a_map_1(self): """ From 896ac68308ed306ea133d1685a19f03673b72122 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 11 Sep 2024 08:48:28 +0200 Subject: [PATCH 072/115] Modified the `shared_data` attribute of teh `MapFusionHelper`. Before it was a DaCe Property, but I relaized now that it should actually be a plain data member. This also solves lots of issues I had with serialization. --- .../dataflow/map_fusion_helper.py | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index cfdbeb3560..2adf1f6ae5 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -47,17 +47,10 @@ class MapFusionHelper(transformation.SingleStateTransformation): default=False, desc="If `True` then the transformation will ensure a more stricter data flow.", ) - shared_data = properties.DictProperty( - key_type=SDFG, - value_type=set, #[str] - default=None, - allow_none=True, - optional=True, # Do not serialize. - optional_condition=lambda _: False, - desc="Maps SDFGs to the set of data that can not be removed," - " because they transmit data _between states_, such data will be made 'shared'." - " This variable acts as a cache, and is managed by 'is_shared_data()'.", - ) + # Maps SDFGs to the set of data that can not be removed, + # because they transmit data _between states_, such data will be made 'shared'. + # This variable acts as a cache, and is managed by 'is_shared_data()'. + _shared_data: Dict[SDFG, Set[str]] def __init__( self, @@ -73,7 +66,7 @@ def __init__( self.only_inner_maps = bool(only_inner_maps) if strict_dataflow is not None: self.strict_dataflow = bool(strict_dataflow) - self.shared_data = {} + self._shared_data = {} @classmethod def expressions(cls) -> bool: @@ -359,9 +352,9 @@ def is_shared_data( There is no mechanism to detect if the cache must be evicted. However, as long as no additional data is added, there is no problem. """ - if sdfg not in self.shared_data: + if sdfg not in self._shared_data: self._compute_shared_data(sdfg) - return data.data in self.shared_data[sdfg] + return data.data in self._shared_data[sdfg] def _compute_shared_data( self, @@ -436,7 +429,7 @@ def _compute_shared_data( shared_data.update(interstate_read_symbols.intersection(prevously_seen_data)) # Update the internal cache - self.shared_data[sdfg] = shared_data + self._shared_data[sdfg] = shared_data def _compute_multi_write_data( self, From fcffb220433b0c1aa28b520b4fc945b88b7e1123 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 12 Sep 2024 08:28:35 +0200 Subject: [PATCH 073/115] This compute offset function seems to solve all my problems. I will now add a very complicated test to ensure that it realy does what I want. --- .../dataflow/map_fusion_helper.py | 1 + .../dataflow/map_fusion_serial.py | 74 +++++++++++++++++-- 2 files changed, 70 insertions(+), 5 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index 2adf1f6ae5..a4ce5de01a 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -133,6 +133,7 @@ def relocate_nodes( once for the entry and then for the exit. While it does not remove the node themselves if guarantees that the `from_node` has degree zero. + The function assumes that the parameter renaming was already done. Args: from_node: Node from which the edges should be removed. diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py index 6c058e8043..60701540a2 100644 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -289,6 +289,58 @@ def test_if_subsets_are_point_wise(self, subsets_to_check: List[subsets.Subset]) # point wise return True + def compute_offset_subset( + self, + original_subset: subsets.Range, + intermediate_desc: data.Data, + map_params: List[str], + producer_offset: Optional[subsets.Range] = None, + ) -> subsets.Range: + """Computes the memlet to correct read and writes of the intermediate. + + Args: + original_subset: The original subset that was used to write into the + intermediate, must be renamed to the final map parameter. + intermediate_desc: The original intermediate data descriptor. + map_params: The parameter of the final map. + """ + assert not isinstance(intermediate_desc, data.View) + final_offset: subsets.Range = None + if isinstance(intermediate_desc, data.Scalar): + final_offset = subsets.Range.from_string("0") + + elif isinstance(intermediate_desc, data.Array): + basic_offsets = original_subset.min_element() + offset_list = [] + for d in range(original_subset.dims()): + d_range = subsets.Range([original_subset[d]]) + if d_range.free_symbols.intersection(map_params): + offset_list.append(d_range[0]) + else: + offset_list.append((basic_offsets[d], basic_offsets[d], 1)) + final_offset = subsets.Range(offset_list) + + else: + raise TypeError(f"Does not know how to compute the subset offset for '{type(intermediate_desc).__name__}'.") + + if producer_offset is not None: + # Here we are correcting some parts that over approximate (which partially + # does under approximate) might screw up. Consider two maps, the first + # map only writes the subset `[:, 2:6]`, thus the new intermediate will + # have shape `(1, 4)`. Now also imagine that the second map only reads + # the elements `[:, 3]`. From this we see that we can only correct the + # consumer side if we also take the producer side into consideration! + # See also the `transformations/mapfusion_test.py::test_offset_correction_*` + # tests for more. + final_offset.offset( + final_offset.offset_new( + producer_offset, + negative=True, + ), + negative=True, + ) + return final_offset + def partition_first_outputs( self, state: SDFGState, @@ -648,6 +700,7 @@ def handle_intermediate_set( output set, see `partition_first_outputs()`. The main difference is that in exclusive mode the intermediate nodes will be fully removed from the SDFG. While in shared mode the intermediate node will be preserved. + The function assumes that the parameter renaming was already done. Args: intermediate_outputs: The set of outputs, that should be processed. @@ -663,6 +716,8 @@ def handle_intermediate_set( after this function has run the state is (most likely) invalid. """ + map_params = map_exit_1.map.params.copy() + # Now we will iterate over all intermediate edges and process them. # If not stated otherwise the comments assume that we run in exclusive mode. for out_edge in intermediate_outputs: @@ -732,7 +787,11 @@ def handle_intermediate_set( # Memlets, since they now write into the new (smaller) intermediate. assert pre_exit_edge.data.data == inter_name assert pre_exit_edge.data.dst_subset is not None - old_pre_exit_edge_subset = pre_exit_edge.data.dst_subset + producer_offset = self.compute_offset_subset( + original_subset=pre_exit_edge.data.dst_subset, + intermediate_desc=inter_desc, + map_params=map_params, + ) # Memlets have a lot of additional informations, such as dynamic. # To ensure that we get all of them, we will now copy them and modify @@ -769,7 +828,7 @@ def handle_intermediate_set( # Since we now write into a smaller memory patch, we must # compensate for that. We do this by substracting where the write # originally had begun. - producer_edge.data.dst_subset.offset(old_pre_exit_edge_subset, negative=True) + producer_edge.data.dst_subset.offset(producer_offset, negative=True) producer_edge.data.dst_subset.pop(squeezed_dims) # Now after we have handled the input of the new intermediate node, @@ -801,7 +860,12 @@ def handle_intermediate_set( # As for the producer side, we now read from a smaller array, # So we must offset them, we use the original edge for this. assert inner_edge.data.src_subset is not None - inner_edge_correction_offset = copy.deepcopy(inner_edge.data.src_subset) + consumer_offset = self.compute_offset_subset( + original_subset=inner_edge.data.src_subset, + intermediate_desc=inter_desc, + map_params=map_params, + producer_offset=producer_offset, + ) # Now we create a new connection that instead reads from the new # intermediate, instead of the old one. For this we use the @@ -824,7 +888,7 @@ def handle_intermediate_set( if is_scalar: new_inner_memlet.subset = "0" elif new_inner_memlet.src_subset is not None: - new_inner_memlet.src_subset.offset(inner_edge_correction_offset, negative=True) + new_inner_memlet.src_subset.offset(consumer_offset, negative=True) new_inner_memlet.src_subset.pop(squeezed_dims) # Now we have to make sure that all consumers are properly updated. @@ -836,7 +900,7 @@ def handle_intermediate_set( if is_scalar: consumer_edge.data.src_subset = "0" elif consumer_edge.data.src_subset is not None: - consumer_edge.data.src_subset.offset(inner_edge_correction_offset, negative=True) + consumer_edge.data.src_subset.offset(consumer_offset, negative=True) consumer_edge.data.src_subset.pop(squeezed_dims) # The edge that leaves the second map entry was already deleted. We now delete From 0ddb3c2f84b6693b5ce4d1b7617e6eafe5557c17 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 12 Sep 2024 11:15:26 +0200 Subject: [PATCH 074/115] Added a test for the special case. Let's see if the CI can handle it. --- tests/transformations/mapfusion_test.py | 107 +++++++++++++++++++++++- 1 file changed, 106 insertions(+), 1 deletion(-) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 476a77a7ba..0702f2dfb7 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -8,7 +8,7 @@ from dace import SDFG, SDFGState from dace.sdfg import nodes -from dace.transformation.dataflow import MapFusionSerial, MapFusionParallel +from dace.transformation.dataflow import MapFusionSerial, MapFusionParallel, MapExpansion def count_node(sdfg: SDFG, node_type): @@ -628,6 +628,108 @@ def inner_product( assert all(in_edge.src is last_map_entry for in_edge in state.in_edges(indirect_access_tasklet)) +def make_correction_offset_sdfg( + range_read: bool, + second_read_start: int, +) -> SDFG: + """Make the SDFGs for the `test_offset_correction_*` tests. + + Args: + range_read: If `True` then a range is read in the second map. + if `False` then only a scalar is read. + second_read_start: Where the second map should start reading. + """ + sdfg = SDFG("offset_correction_test") + state = sdfg.add_state(is_start_block=True) + shapes = { + "A": (20, 10), + "B": (20, 8), + "C": (20, 2) if range_read else (20, 1), + } + descs = {} + for name, shape in shapes.items(): + _, desc = sdfg.add_array(name, shape, dace.float64, transient=False) + descs[name] = desc + sdfg.arrays["B"].transient = True + A, B, C = (state.add_access(name) for name in sorted(shapes.keys())) + + state.add_mapped_tasklet( + "first_map", + map_ranges={"i": "0:20", "j": "2:8"}, + inputs={"__in1": dace.Memlet("A[i, j]")}, + code="__out = __in1 + 1.0", + outputs={"__out": dace.Memlet("B[i, j]")}, + input_nodes={"A": A}, + output_nodes={"B": B}, + external_edges=True, + ) + state.add_mapped_tasklet( + "second_map", + map_ranges=( + {"i": "0:20", "k": "0:2"} + if range_read + else {"i": "0:20"} + ), + inputs={"__in1": dace.Memlet(f"B[i, {second_read_start}{'+k' if range_read else ''}]")}, + code="__out = __in1", + outputs={"__out": dace.Memlet(f"C[i, {'k' if range_read else '0'}]")}, + input_nodes={"B": B}, + output_nodes={"C": C}, + external_edges=True, + ) + sdfg.validate() + assert sdfg.apply_transformations_repeated(MapExpansion, validate_all=True) > 0 + return sdfg + + +def test_offset_correction_range_read(): + + np.random.seed(42) + A = np.random.rand(20, 10) + C = np.zeros((20, 2)) + exp = (A + 1.0)[:, 3:5].copy() + + sdfg = make_correction_offset_sdfg(range_read=True, second_read_start=3) + + sdfg(A=A, C=C) + assert np.allclose(C, exp) + C[:] = 0.0 + + apply_fusion(sdfg) + + sdfg(A=A, C=C) + assert np.allclose(C, exp) + + +def test_offset_correction_scalar_read(): + + np.random.seed(42) + A = np.random.rand(20, 10) + C = np.zeros((20, 1)) + exp = (A + 1.0)[:, 3].copy().reshape((-1, 1)) + + sdfg = make_correction_offset_sdfg(range_read=False, second_read_start=3) + + sdfg(A=A, C=C) + assert np.allclose(C, exp) + C[:] = 0.0 + + apply_fusion(sdfg) + + sdfg(A=A, C=C) + assert np.allclose(C, exp) + + +def test_offset_correction_empty(): + + # Because the second map starts reading from 1, but the second map only + # starts writing from 2 there is no overlap and it can not be fused. + # NOTE: This computation is useless. + sdfg = make_correction_offset_sdfg(range_read=True, second_read_start=1) + + apply_fusion(sdfg, removed_maps=0) + + if __name__ == '__main__': test_indirect_accesses() test_fusion_shared() @@ -644,5 +746,8 @@ def inner_product( test_fusion_with_nested_sdfg_1() test_parallel_fusion_simple() test_fuse_indirect_accesses() + test_offset_correction_range_read() + test_offset_correction_scalar_read() + test_offset_correction_empty() print("SUCCESS") From 05ffee45f585bfd2fd94224fe2688f16b57f3ee1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 12 Sep 2024 14:08:14 +0200 Subject: [PATCH 075/115] Did some cleanup. --- .../dataflow/map_fusion_helper.py | 21 ++++++++----------- .../dataflow/map_fusion_parallel.py | 4 ++-- .../dataflow/map_fusion_serial.py | 16 +++++--------- 3 files changed, 16 insertions(+), 25 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py index a4ce5de01a..deadeee5b4 100644 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ b/dace/transformation/dataflow/map_fusion_helper.py @@ -1,15 +1,12 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. """Implements Helper functionaliyies for map fusion""" -import functools -import itertools -import re import copy -from typing import Any, Dict, Iterable, List, Optional, Set, Sequence, Tuple, Union, overload +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import dace -from dace import data, properties, subsets, transformation, symbolic -from dace.sdfg import SDFG, SDFGState, graph, nodes, validation, replace +from dace import data, properties, subsets, symbolic, transformation +from dace.sdfg import SDFG, SDFGState, nodes, validation from dace.transformation import helpers @@ -240,7 +237,7 @@ def find_parameter_remapping(self, first_map: nodes.Map, second_map: nodes.Map) param: tuple(simp(r) for r in rng) for param, rng in zip(first_params, first_map.range) } - second_rngs: Dict[Tuple[Any, Any, Any]] = { + second_rngs: Dict[str, Tuple[Any, Any, Any]] = { param: tuple(simp(r) for r in rng) for param, rng in zip(second_params, second_map.range) } @@ -464,7 +461,7 @@ def _compute_multi_write_data( multi_write_data.add(access_node.data) elif self.is_view(access_node, sdfg): # This is an over approximation. - multi_write_data.update([access_node.data, track_view(access_node, state, sdfg).data]) + multi_write_data.update([access_node.data, self.track_view(access_node, state, sdfg).data]) data_written_to.add(access_node.data) return multi_write_data @@ -489,11 +486,11 @@ def is_node_reachable_from( def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: return (edge.dst for edge in graph.out_edges(node)) - to_visit: List[dace_nodes.Node] = [begin] - seen: Set[dace_nodes.Node] = set() + to_visit: List[nodes.Node] = [begin] + seen: Set[nodes.Node] = set() while len(to_visit) > 0: - node: dace_nodes.Node = to_visit.pop() + node: nodes.Node = to_visit.pop() if node == end: return True elif node not in seen: @@ -569,7 +566,7 @@ def find_subsets( found_subsets: List[subsets.Subset] = [] for edge in outer_edges_to_inspect: found_subsets.extend(get_subset(e) for e in get_inner_edges(edge)) - assert len(found_subsets) > 0, f"Could not find any subsets." + assert len(found_subsets) > 0, "Could not find any subsets." assert not any(subset is None for subset in found_subsets) found_subsets = copy.deepcopy(found_subsets) diff --git a/dace/transformation/dataflow/map_fusion_parallel.py b/dace/transformation/dataflow/map_fusion_parallel.py index 63c75f1d56..41e8e3bd3d 100644 --- a/dace/transformation/dataflow/map_fusion_parallel.py +++ b/dace/transformation/dataflow/map_fusion_parallel.py @@ -7,11 +7,11 @@ from dace import properties, transformation from dace.sdfg import SDFG, SDFGState, graph, nodes -from dace.transformation.dataflow import map_fusion_helper +from dace.transformation.dataflow import map_fusion_helper as mfh @properties.make_properties -class MapFusionParallel(map_fusion_helper.MapFusionHelper): +class MapFusionParallel(mfh.MapFusionHelper): """The `MapFusionParallel` transformation allows to merge two parallel maps. While the `MapFusionSerial` transformation fuses maps that are sequentially diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py index 60701540a2..613d431d74 100644 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -2,7 +2,7 @@ """Implements the serial map fusing transformation.""" import copy -from typing import Any, Dict, List, Set, Tuple, Union, Optional +from typing import Any, Dict, List, Optional, Set, Tuple, Union import dace from dace import data, dtypes, properties, subsets, symbolic, transformation @@ -167,13 +167,13 @@ def has_read_write_dependency( resolved_sets: List[Set[str]] = [] for unresolved_set in [read_map_1, write_map_1, read_map_2, write_map_2]: resolved_sets.append({ - self.track_view(node).data if self.is_view(node, sdfg) else node.data + self.track_view(node, state, sdfg).data if self.is_view(node, sdfg) else node.data for node in unresolved_set.values() }) # If the resolved and unresolved names do not have the same length. # Then different views point to the same location, which we forbid if len(unresolved_set) != len(resolved_sets[-1]): - return None + return False real_read_map_1, real_write_map_1, real_read_map_2, real_write_map_2 = resolved_sets # We do not allow that the first and second map each write to the same data. @@ -189,7 +189,7 @@ def has_read_write_dependency( # to transmit information between the maps. The partition function ensures that # these nodes are directly connected to the two maps. exchange_names: Set[str] = set(write_map_1.keys()).intersection(read_map_2.keys()) - exchange_nodes: Set[node.AccessNode] = set(write_map_1.values()).intersection(read_map_2.values()) + exchange_nodes: Set[nodes.AccessNode] = set(write_map_1.values()).intersection(read_map_2.values()) # If the number are different then a data is accessed through multiple nodes. if len(exchange_names) != len(exchange_nodes): @@ -271,7 +271,7 @@ def test_if_subsets_are_point_wise(self, subsets_to_check: List[subsets.Subset]) if isinstance(subset, subsets.Indices): subset = subsets.Range.from_indices(subset) # Do we also need the reverse? See below why. - if any(r != (0, 0, 1) for r in test in subset.offset_new(master_subset, negative=True)): + if any(r != (0, 0, 1) for r in subset.offset_new(master_subset, negative=True)): return False else: # The original code used `Range.offset` here, but that one had trouble @@ -393,11 +393,6 @@ def partition_first_outputs( exclusive_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() shared_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() - # These are the iteration parameters of the two maps. - # They are not yet modified, that they match each other. - map_params_1: Sequence[str] = map_exit_1.map.params - map_params_2: Sequence[str] = map_entry_2.map.params - # Compute the renaming that for translating the parameter of the _second_ # map to the ones used by the first map. repl_dict: Dict[str, str] = self.find_parameter_remapping( @@ -460,7 +455,6 @@ def partition_first_outputs( # handled has shared intermediates. if not isinstance(intermediate_node, nodes.AccessNode): return None - intermediate_desc: data.Data = intermediate_node.desc(sdfg) if self.is_view(intermediate_node, sdfg): return None From 8c86662aa437b2243511762a8a2d9258959d1419 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 13 Sep 2024 08:11:59 +0200 Subject: [PATCH 076/115] Specified how the corrector function of the offsets works. --- .../dataflow/map_fusion_serial.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py index 613d431d74..f284b4c7c8 100644 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ b/dace/transformation/dataflow/map_fusion_serial.py @@ -294,20 +294,32 @@ def compute_offset_subset( original_subset: subsets.Range, intermediate_desc: data.Data, map_params: List[str], - producer_offset: Optional[subsets.Range] = None, + producer_offset: Union[subsets.Range, None], ) -> subsets.Range: """Computes the memlet to correct read and writes of the intermediate. + This is the value that must be substracted from the memlets to adjust, i.e + (`memlet_to_adjust(correction, negative=True)`). If `producer_offset` is + `None` then the function computes the correction that should be applied to + the producer memlets, i.e. the memlets of the tree converging at + `intermediate_node`. If `producer_offset` is given, it should be the output + of the previous call to this function, with `producer_offset=None`. In this + case the function computes the correction for the consumer side, i.e. the + memlet tree that originates at `intermediate_desc`. + Args: original_subset: The original subset that was used to write into the intermediate, must be renamed to the final map parameter. intermediate_desc: The original intermediate data descriptor. map_params: The parameter of the final map. + producer_offset: The correction that was applied to the producer side. """ assert not isinstance(intermediate_desc, data.View) final_offset: subsets.Range = None if isinstance(intermediate_desc, data.Scalar): - final_offset = subsets.Range.from_string("0") + # If the intermediate was a scalar, then it will remain a scalar. + # Thus there is no correction that we must apply. + return subsets.Range.from_string("0") elif isinstance(intermediate_desc, data.Array): basic_offsets = original_subset.min_element() @@ -785,6 +797,7 @@ def handle_intermediate_set( original_subset=pre_exit_edge.data.dst_subset, intermediate_desc=inter_desc, map_params=map_params, + producer_offset=None, ) # Memlets have a lot of additional informations, such as dynamic. From 11a316744d6c4117d3f6b97b4141afcf651f1cbe Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 26 Sep 2024 08:17:49 +0200 Subject: [PATCH 077/115] UPdated some comments. --- dace/sdfg/state.py | 42 +++++++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index c40c05aa7f..4844720fa1 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -747,6 +747,7 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, Determines what data is read and written in this subgraph, returning dictionaries from data containers to all subsets that are read/written. """ + from dace.sdfg import utils # Avoid cyclic import # Ensures that the `{src,dst}_subset` are properly set. # TODO: find where the problems are @@ -755,23 +756,30 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, read_set = collections.defaultdict(list) write_set = collections.defaultdict(list) - from dace.sdfg import utils # Avoid cyclic import - subgraphs = utils.concurrent_subgraphs(self) - for sg in subgraphs: - rs = collections.defaultdict(list) - ws = collections.defaultdict(list) + + for subgraph in utils.concurrent_subgraphs(self): + subgraph_read_set = collections.defaultdict(list) # read and write set of this subgraph. + subgraph_write_set = collections.defaultdict(list) # Traverse in topological order, so data that is written before it # is read is not counted in the read set # TODO: This only works if every data descriptor is only once in a path. - for n in utils.dfs_topological_sort(sg, sources=sg.source_nodes()): + for n in utils.dfs_topological_sort(subgraph, sources=subgraph.source_nodes()): if not isinstance(n, nd.AccessNode): + # Read and writes can only be done through access nodes, + # so ignore every other node. continue + + # Get a list of all incoming (writes) and outgoing (reads) edges of the + # access node, ignore all empty memlets as they do not carry any data. + in_edges = [in_edge for in_edge in subgraph.in_edges(n) if not in_edge.data.is_empty()] + out_edges = [out_edge for out_edge in subgraph.out_edges(n) if not out_edge.data.is_empty()] + + # Extract the subsets that describes where we read and write the data + # and store them for the later filtering. + # NOTE: In certain cases the corresponding subset might be None, in this case + # we assume that the whole array is written, which is the default behaviour. ac_desc = n.desc(self.sdfg) ac_size = ac_desc.total_size - in_edges = [in_edge for in_edge in sg.in_edges(n) if not in_edge.data.is_empty()] - out_edges = [out_edge for out_edge in sg.out_edges(n) if not out_edge.data.is_empty()] - - # In some conditions subsets can be `None`, we will now clean them. in_subsets = dict() for in_edge in in_edges: assert in_edge.data.dst_subset is not None or (in_edge.data.num_elements() == ac_size) @@ -789,12 +797,12 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, else out_edge.data.src_subset ) - # Filter out memlets which go out but the same data is written to the AccessNode by another memlet + # Filter out reads that are also written at the access node by another (single) write. for out_edge in list(out_edges): for in_edge in in_edges: if out_edge.data.data != in_edge.data.data: - # NOTE: This check does not make any sense, and is in my view wrong. - # As it will filter out some accesses but not all, which one solely + # NOTE: This check does not make any sense, and is in my (@philip-paul-mueller) + # view wrong. As it will filter out some accesses but not all, which one solely # depends on how the memelts were created. # See also [issue #1643](https://github.com/spcl/dace/issues/1643). continue @@ -803,16 +811,16 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, break if in_edges: - ws[n.data].extend(in_subsets.values()) + subgraph_write_set[n.data].extend(in_subsets.values()) if out_edges: - rs[n.data].extend(out_subsets[out_edge] for out_edge in out_edges) + subgraph_read_set[n.data].extend(out_subsets[out_edge] for out_edge in out_edges) # Union all subgraphs, so an array that was excluded from the read # set because it was written first is still included if it is read # in another subgraph - for data, accesses in rs.items(): + for data, accesses in subgraph_read_set.items(): read_set[data] += accesses - for data, accesses in ws.items(): + for data, accesses in subgraph_write_set.items(): write_set[data] += accesses return copy.deepcopy((read_set, write_set)) From 44cf6ad6d8ceb43371dc47877f3c5688738d6857 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 26 Sep 2024 08:30:44 +0200 Subject: [PATCH 078/115] Added more comments. --- dace/sdfg/state.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 4844720fa1..f6adb06801 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -782,12 +782,15 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, ac_size = ac_desc.total_size in_subsets = dict() for in_edge in in_edges: - assert in_edge.data.dst_subset is not None or (in_edge.data.num_elements() == ac_size) + # Ensure that if the destination subset is not given, our assumption, that the + # whole array is written to, is valid, by testing if the memlet transfers the + # whole array. + assert (in_edge.data.dst_subset is not None) or (in_edge.data.num_elements() == ac_size) in_subsets[in_edge] = ( sbs.Range.from_array(ac_desc) if in_edge.data.dst_subset is None else in_edge.data.dst_subset - ) + ) out_subsets = dict() for out_edge in out_edges: assert (out_edge.data.src_subset is not None) or (out_edge.data.num_elements() == ac_size) @@ -797,31 +800,35 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, else out_edge.data.src_subset ) - # Filter out reads that are also written at the access node by another (single) write. + # If a memlet reads a particular region of data from the access node and there + # exists a memlet at the same access node that writes to the same region, then + # this read is ignored, and not included in the final read set, but only + # accounted fro in the write set. See also note below. + # TODO: Handle the case when multiple disjoint writes are needed to cover the read. for out_edge in list(out_edges): for in_edge in in_edges: if out_edge.data.data != in_edge.data.data: # NOTE: This check does not make any sense, and is in my (@philip-paul-mueller) # view wrong. As it will filter out some accesses but not all, which one solely - # depends on how the memelts were created. - # See also [issue #1643](https://github.com/spcl/dace/issues/1643). + # depends on how the memelts were created, i.e. to which container their `data` + # attribute is associated to. See also [issue #1643](https://github.com/spcl/dace/issues/1643). continue if in_subsets[in_edge].covers(out_subsets[out_edge]): out_edges.remove(out_edge) break + # Update the read and write sets of the subgraph. if in_edges: subgraph_write_set[n.data].extend(in_subsets.values()) if out_edges: subgraph_read_set[n.data].extend(out_subsets[out_edge] for out_edge in out_edges) - # Union all subgraphs, so an array that was excluded from the read - # set because it was written first is still included if it is read - # in another subgraph + # Add the subgraph's read and write set to the final ones. for data, accesses in subgraph_read_set.items(): read_set[data] += accesses for data, accesses in subgraph_write_set.items(): write_set[data] += accesses + return copy.deepcopy((read_set, write_set)) def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: From 5e25816a4ac74ae1695d1b5097f1f9d090cbf713 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 1 Nov 2024 09:16:29 +0100 Subject: [PATCH 079/115] Removed the parallel map fusion transformation. Now only the serial version is there. Also integrated the helper into the serial file. --- dace/transformation/dataflow/__init__.py | 2 +- dace/transformation/dataflow/map_fusion.py | 1520 ++++++++++++++++- .../dataflow/map_fusion_helper.py | 632 ------- .../dataflow/map_fusion_parallel.py | 156 -- .../dataflow/map_fusion_serial.py | 966 ----------- 5 files changed, 1516 insertions(+), 1760 deletions(-) delete mode 100644 dace/transformation/dataflow/map_fusion_helper.py delete mode 100644 dace/transformation/dataflow/map_fusion_parallel.py delete mode 100644 dace/transformation/dataflow/map_fusion_serial.py diff --git a/dace/transformation/dataflow/__init__.py b/dace/transformation/dataflow/__init__.py index a336764dc0..6fa274f041 100644 --- a/dace/transformation/dataflow/__init__.py +++ b/dace/transformation/dataflow/__init__.py @@ -8,7 +8,7 @@ from .map_for_loop import MapToForLoop, MapToForLoopRegion from .map_interchange import MapInterchange from .map_dim_shuffle import MapDimShuffle -from .map_fusion import MapFusion, MapFusionParallel, MapFusionSerial +from .map_fusion import MapFusion from .map_fission import MapFission from .map_unroll import MapUnroll from .trivial_map_elimination import TrivialMapElimination diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index e737fab91c..032cf45634 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -1,8 +1,1518 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -"""Make all map fusion transformations available.""" +"""Implements the serial map fusing transformation.""" -from .map_fusion_serial import MapFusionSerial -from .map_fusion_parallel import MapFusionParallel +import copy +from typing import Any, Dict, List, Optional, Set, Tuple, Union -# Compatibility with previous versions of DaCe and clients. -MapFusion = MapFusionSerial +import dace +from dace import data, dtypes, properties, subsets, symbolic, transformation +from dace.sdfg import SDFG, SDFGState, graph, nodes, validation +from dace.transformation import helpers + + +@properties.make_properties +class MapFusion(transformation.SingleStateTransformation): + """Fuse two serial maps together. + + The transformation combines two maps into one that are connected through some + access nodes. Conceptually this transformation removes the exit of the first + or upper map and the entry of the lower or second map and then rewrites the + connections appropriately. Depending on the situation the transformation will + either fully remove or make the intermediate a new output of the second map. + + By default, the transformation does not use the strict data flow mode. However, + it might be useful in come cases to enable it. + + Args: + only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + only_toplevel_maps: Only consider Maps that are at the top. + strict_dataflow: If `True`, the transformation ensures a more + stricter version of the data flow. + + Notes: + - This transformation modifies more nodes than it matches. + - After the transformation has been applied simplify should be run to remove + some dead data flow, that was introduced to ensure validity. + - A `MapFusion` obejct can be initialized and be reused. However, + after new access nodes are added to any state, it is no longer valid + to use the object. + + Todo: + - Consider the case that only shared nodes are created (thus no inspection of + the graph is needed) and make all shared. Then use the dead dataflow + elimination transformation to get rid of the ones we no longer need. + - Increase the applicability. + """ + + # Pattern Nodes + map_exit_1 = transformation.transformation.PatternNode(nodes.MapExit) + intermediate_access_node = transformation.transformation.PatternNode(nodes.AccessNode) + map_entry_2 = transformation.transformation.PatternNode(nodes.MapEntry) + + + # Settings + only_toplevel_maps = properties.Property( + dtype=bool, + default=False, + desc="Only perform fusing if the Maps are in the top level.", + ) + only_inner_maps = properties.Property( + dtype=bool, + default=False, + desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", + ) + strict_dataflow = properties.Property( + dtype=bool, + default=False, + desc="If `True` then the transformation will ensure a more stricter data flow.", + ) + # Maps SDFGs to the set of data that can not be removed, + # because they transmit data _between states_, such data will be made 'shared'. + # This variable acts as a cache, and is managed by 'is_shared_data()'. + _shared_data: Dict[SDFG, Set[str]] + + + def __init__( + self, + only_inner_maps: Optional[bool] = None, + only_toplevel_maps: Optional[bool] = None, + strict_dataflow: Optional[bool] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if only_toplevel_maps is not None: + self.only_toplevel_maps = bool(only_toplevel_maps) + if only_inner_maps is not None: + self.only_inner_maps = bool(only_inner_maps) + if strict_dataflow is not None: + self.strict_dataflow = bool(strict_dataflow) + self._shared_data = {} + + + @classmethod + def expressions(cls) -> Any: + """Get the match expression. + + The transformation matches the exit node of the top Map that is connected to + an access node that again is connected to the entry node of the second Map. + An important note is, that the transformation operates not just on the + matched nodes, but more or less on anything that has an incoming connection + from the first Map or an outgoing connection to the second Map entry. + """ + return [dace.sdfg.utils.node_path_graph(cls.map_exit_1, cls.intermediate_access_node, cls.map_entry_2)] + + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Tests if the matched Maps can be merged. + + The two Maps are mergeable iff: + - Checks general requirements, see `can_topologically_be_fused()`. + - Tests if there are read write dependencies. + - Tests if the decomposition exists. + """ + map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit_1) + map_exit_1: nodes.MapExit = self.map_exit_1 + map_entry_2: nodes.MapEntry = self.map_entry_2 + + # This essentially test the structural properties of the two Maps. + if not self.can_topologically_be_fused(map_entry_1=map_entry_1, map_entry_2=map_entry_2, graph=graph, sdfg=sdfg): + return False + + # Test for read-write conflicts + if self.has_read_write_dependency( + map_entry_1=map_entry_1, + map_entry_2=map_entry_2, + state=graph, + sdfg=sdfg, + ): + return False + + # Two maps can be serially fused if the node decomposition exists and + # at least one of the intermediate output sets is not empty. The state + # of the pure outputs is irrelevant for serial map fusion. + output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + ) + if output_partition is None: + return False + _, exclusive_outputs, shared_outputs = output_partition + if not (exclusive_outputs or shared_outputs): + return False + return True + + + def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: + """Performs the serial Map fusing. + + The function first computes the map decomposition and then handles the + three sets. The pure outputs are handled by `relocate_nodes()` while + the two intermediate sets are handled by `handle_intermediate_set()`. + + By assumption we do not have to rename anything. + + Args: + graph: The SDFG state we are operating on. + sdfg: The SDFG we are operating on. + """ + # NOTE: `self.map_*` actually stores the ID of the node. + # once we start adding and removing nodes it seems that their ID changes. + # Thus we have to save them here, this is a known behaviour in DaCe. + assert isinstance(graph, dace.SDFGState) + assert isinstance(self.map_exit_1, nodes.MapExit) + assert isinstance(self.map_entry_2, nodes.MapEntry) + + map_exit_1: nodes.MapExit = self.map_exit_1 + map_entry_2: nodes.MapEntry = self.map_entry_2 + map_exit_2: nodes.MapExit = graph.exit_node(self.map_entry_2) + map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit_1) + + # Before we do anything we perform the renaming. + self.rename_map_parameters( + first_map=map_exit_1.map, + second_map=map_entry_2.map, + second_map_entry=map_entry_2, + state=graph, + ) + + output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + ) + assert output_partition is not None # Make MyPy happy. + pure_outputs, exclusive_outputs, shared_outputs = output_partition + + if len(exclusive_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=exclusive_outputs, + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + map_exit_2=map_exit_2, + is_exclusive_set=True, + ) + if len(shared_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=shared_outputs, + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + map_exit_2=map_exit_2, + is_exclusive_set=False, + ) + assert pure_outputs == set(graph.out_edges(map_exit_1)) + if len(pure_outputs) != 0: + self.relocate_nodes( + from_node=map_exit_1, + to_node=map_exit_2, + state=graph, + sdfg=sdfg, + ) + + # Above we have handled the input of the second map and moved them + # to the first map, now we must move the output of the first map + # to the second one, as this one is used. + self.relocate_nodes( + from_node=map_entry_2, + to_node=map_entry_1, + state=graph, + sdfg=sdfg, + ) + + for node_to_remove in [map_exit_1, map_entry_2]: + assert graph.degree(node_to_remove) == 0 + graph.remove_node(node_to_remove) + + # Now turn the second output node into the output node of the first Map. + map_exit_2.map = map_entry_1.map + + + def partition_first_outputs( + self, + state: SDFGState, + sdfg: SDFG, + map_exit_1: nodes.MapExit, + map_entry_2: nodes.MapEntry, + ) -> Union[ + Tuple[ + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], + ], + None, + ]: + """Partition the output edges of `map_exit_1` for serial map fusion. + + The output edges of the first map are partitioned into three distinct sets, + defined as follows: + - Pure Output Set `\mathbb{P}`: + These edges exits the first map and does not enter the second map. These + outputs will be simply be moved to the output of the second map. + - Exclusive Intermediate Set `\mathbb{E}`: + Edges in this set leaves the first map exit, enters an access node, from + where a Memlet then leads immediately to the second map. The memory + referenced by this access node is not used anywhere else, thus it can + be removed. + - Shared Intermediate Set `\mathbb{S}`: + These edges are very similar to the one in `\mathbb{E}` except that they + are used somewhere else, thus they can not be removed and have to be + recreated as output of the second map. + + If strict data flow mode is enabled the function is rather strict if an + output can be added to either intermediate set and might fail to compute + the partition, even if it would exist. + + Returns: + If such a decomposition exists the function will return the three sets + mentioned above in the same order. + In case the decomposition does not exist, i.e. the maps can not be fused + the function returns `None`. + + Args: + state: The in which the two maps are located. + sdfg: The full SDFG in whcih we operate. + map_exit_1: The exit node of the first map. + map_entry_2: The entry node of the second map. + """ + # The three outputs set. + pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + exclusive_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + shared_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + + # Compute the renaming that for translating the parameter of the _second_ + # map to the ones used by the first map. + repl_dict: Dict[str, str] = self.find_parameter_remapping( + first_map=map_exit_1.map, + second_map=map_entry_2.map, + ) + assert repl_dict is not None + + # Set of intermediate nodes that we have already processed. + processed_inter_nodes: Set[nodes.Node] = set() + + # These are the data that is written to multiple times in _this_ state. + # If a data is written to multiple time in a state, it could be + # classified as shared. However, it might happen that the node has zero + # degree. This is not a problem as the maps also induced a before-after + # relationship. But some DaCe transformations do not catch this. + # Thus we will never modify such intermediate nodes and fail instead. + if self.strict_dataflow: + multi_write_data: Set[str] = self._compute_multi_write_data(state, sdfg) + else: + multi_write_data = set() + + # Now scan all output edges of the first exit and classify them + for out_edge in state.out_edges(map_exit_1): + intermediate_node: nodes.Node = out_edge.dst + + # We already processed the node, this should indicate that we should + # run simplify again, or we should start implementing this case. + # TODO(phimuell): Handle this case, already partially handled here. + if intermediate_node in processed_inter_nodes: + return None + processed_inter_nodes.add(intermediate_node) + + # The intermediate can only have one incoming degree. It might be possible + # to handle multiple incoming edges, if they all come from the top map. + # However, the resulting SDFG might be invalid. + # NOTE: Allow this to happen (under certain cases) if the only producer + # is the top map. + if state.in_degree(intermediate_node) != 1: + return None + + # If the second map is not reachable from the intermediate node, then + # the output is pure and we can end here. + if not self.is_node_reachable_from( + graph=state, + begin=intermediate_node, + end=map_entry_2, + ): + pure_outputs.add(out_edge) + continue + + # The following tests are _after_ we have determined if we have a pure + # output node, because this allows us to handle more exotic pure node + # cases, as handling them is essentially rerouting an edge, whereas + # handling intermediate nodes is much more complicated. + + # For us an intermediate node must always be an access node, because + # everything else we do not know how to handle. It is important that + # we do not test for non transient data here, because they can be + # handled has shared intermediates. + if not isinstance(intermediate_node, nodes.AccessNode): + return None + if self.is_view(intermediate_node, sdfg): + return None + + # Checks if the intermediate node refers to data that is accessed by + # _other_ access nodes in _this_ state. If this is the case then never + # touch this intermediate node. + # TODO(phimuell): Technically it would be enough to turn the node into + # a shared output node, because this will still fulfil the dependencies. + # However, some DaCe transformation can not handle this properly, so we + # are _forced_ to reject this node. + if intermediate_node.data in multi_write_data: + return None + + # Empty Memlets are only allowed if they are in `\mathbb{P}`, which + # is also the only place they really make sense (for a map exit). + # Thus if we now found an empty Memlet we reject it. + if out_edge.data.is_empty(): + return None + + # It can happen that multiple edges converges at the `IN_` connector + # of the first map exit, but there is only one edge leaving the exit. + # It is complicate to handle this, so for now we ignore it. + # TODO(phimuell): Handle this case properly. + # To handle this we need to associate a consumer edge (the outgoing edges + # of the second map) with exactly one producer. + producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list( + state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:])) + if len(producer_edges) > 1: + return None + + # Now check the constraints we have on the producers. + # - The source of the producer can not be a view (we do not handle this) + # - The edge shall also not be a reduction edge. + # - Defined location to where they write. + # - No dynamic Melets. + # Furthermore, we will also extract the subsets, i.e. the location they + # modify inside the intermediate array. + # Since we do not allow for WCR, we do not check if the producer subsets intersects. + producer_subsets: List[subsets.Subset] = [] + for producer_edge in producer_edges: + if isinstance(producer_edge.src, nodes.AccessNode) and self.is_view(producer_edge.src, sdfg): + return None + if producer_edge.data.dynamic: + return None + if producer_edge.data.wcr is not None: + return None + if producer_edge.data.dst_subset is None: + return None + producer_subsets.append(producer_edge.data.dst_subset) + + # Check if the producer do not intersect + if len(producer_subsets) == 1: + pass + elif len(producer_subsets) == 2: + if producer_subsets[0].intersects(producer_subsets[1]): + return None + else: + for i, psbs1 in enumerate(producer_subsets): + for j, psbs2 in enumerate(producer_subsets): + if i == j: + continue + if psbs1.intersects(psbs2): + return None + + # Now we determine the consumer of nodes. For this we are using the edges + # leaves the second map entry. It is not necessary to find the actual + # consumer nodes, as they might depend on symbols of nested Maps. + # For the covering test we only need their subsets, but we will perform + # some scan and filtering on them. + found_second_map = False + consumer_subsets: List[subsets.Subset] = [] + for intermediate_node_out_edge in state.out_edges(intermediate_node): + + # If the second map entry is not immediately reachable from the intermediate + # node, then ensure that there is not path that goes to it. + if intermediate_node_out_edge.dst is not map_entry_2: + if self.is_node_reachable_from(graph=state, begin=intermediate_node_out_edge.dst, end=map_entry_2): + return None + continue + + # Ensure that the second map is found exactly once. + # TODO(phimuell): Lift this restriction. + if found_second_map: + return None + found_second_map = True + + # The output of the top map can not define a dynamic map range in the + # second map. + if not intermediate_node_out_edge.dst_conn.startswith("IN_"): + return None + + # Now we look at all edges that leave the second map entry, i.e. the + # edges that feeds the consumer and define what is read inside the map. + # We do not check them, but collect them and inspect them. + # NOTE: The subset still uses the old iteration variables. + for inner_consumer_edge in state.out_edges_by_connector( + map_entry_2, "OUT_" + intermediate_node_out_edge.dst_conn[3:]): + if inner_consumer_edge.data.src_subset is None: + return None + if inner_consumer_edge.data.dynamic: + # TODO(phimuell): Is this restriction necessary, I am not sure. + return None + consumer_subsets.append(inner_consumer_edge.data.src_subset) + assert found_second_map, f"Found '{intermediate_node}' which looked like a pure node, but is not one." + assert len(consumer_subsets) != 0 + + # The consumer still uses the original symbols of the second map, so we must rename them. + if repl_dict: + consumer_subsets = copy.deepcopy(consumer_subsets) + for consumer_subset in consumer_subsets: + symbolic.safe_replace(mapping=repl_dict, replace_callback=consumer_subset.replace) + + # Now we are checking if a single iteration of the first (top) map + # can satisfy all data requirements of the second (bottom) map. + # For this we look if the producer covers the consumer. A consumer must + # be covered by exactly one producer. + for consumer_subset in consumer_subsets: + nb_coverings = sum(producer_subset.covers(consumer_subset) for producer_subset in producer_subsets) + if nb_coverings != 1: + return None + + # After we have ensured coverage, we have to decide if the intermediate + # node can be removed (`\mathbb{E}`) or has to be restored (`\mathbb{S}`). + # Note that "removed" here means that it is reconstructed by a new + # output of the second map. + if self.is_shared_data(intermediate_node, sdfg): + # The intermediate data is used somewhere else, either in this or another state. + shared_outputs.add(out_edge) + else: + # The intermediate can be removed, as it is not used anywhere else. + exclusive_outputs.add(out_edge) + + assert len(processed_inter_nodes) == sum(len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs]) + return (pure_outputs, exclusive_outputs, shared_outputs) + + + def relocate_nodes( + self, + from_node: Union[nodes.MapExit, nodes.MapEntry], + to_node: Union[nodes.MapExit, nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, + ) -> None: + """Move the connectors and edges from `from_node` to `to_nodes` node. + + This function will only rewire the edges, it does not remove the nodes + themselves. Furthermore, this function should be called twice per Map, + once for the entry and then for the exit. + While it does not remove the node themselves if guarantees that the + `from_node` has degree zero. + The function assumes that the parameter renaming was already done. + + Args: + from_node: Node from which the edges should be removed. + to_node: Node to which the edges should reconnect. + state: The state in which the operation happens. + sdfg: The SDFG that is modified. + """ + + # Now we relocate empty Memlets, from the `from_node` to the `to_node` + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.out_edges(from_node))): + helpers.redirect_edge(state, empty_edge, new_src=to_node) + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.in_edges(from_node))): + helpers.redirect_edge(state, empty_edge, new_dst=to_node) + + # We now ensure that there is only one empty Memlet from the `to_node` to any other node. + # Although it is allowed, we try to prevent it. + empty_targets: Set[nodes.Node] = set() + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.all_edges(to_node))): + if empty_edge.dst in empty_targets: + state.remove_edge(empty_edge) + empty_targets.add(empty_edge.dst) + + # We now determine which edges we have to migrate, for this we are looking at + # the incoming edges, because this allows us also to detect dynamic map ranges. + # TODO(phimuell): If there is already a connection to the node, reuse this. + for edge_to_move in list(state.in_edges(from_node)): + assert isinstance(edge_to_move.dst_conn, str) + + if not edge_to_move.dst_conn.startswith("IN_"): + # Dynamic Map Range + # The connector name simply defines a variable name that is used, + # inside the Map scope to define a variable. We handle it directly. + dmr_symbol = edge_to_move.dst_conn + + # TODO(phimuell): Check if the symbol is really unused in the target scope. + if dmr_symbol in to_node.in_connectors: + raise NotImplementedError(f"Tried to move the dynamic map range '{dmr_symbol}' from {from_node}'" + f" to '{to_node}', but the symbol is already known there, but the" + " renaming is not implemented.") + if not to_node.add_in_connector(dmr_symbol, force=False): + raise RuntimeError( # Might fail because of out connectors. + f"Failed to add the dynamic map range symbol '{dmr_symbol}' to '{to_node}'.") + helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) + from_node.remove_in_connector(dmr_symbol) + + else: + # We have a Passthrough connection, i.e. there exists a matching `OUT_`. + old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix + new_conn = to_node.next_connector(old_conn) + + to_node.add_in_connector("IN_" + new_conn) + for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): + helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) + to_node.add_out_connector("OUT_" + new_conn) + for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): + helpers.redirect_edge(state, e, new_src=to_node, new_src_conn="OUT_" + new_conn) + from_node.remove_in_connector("IN_" + old_conn) + from_node.remove_out_connector("OUT_" + old_conn) + + # Check if we succeeded. + if state.out_degree(from_node) != 0: + raise validation.InvalidSDFGError( + f"Failed to relocate the outgoing edges from `{from_node}`, there are still `{state.out_edges(from_node)}`", + sdfg, + sdfg.node_id(state), + ) + if state.in_degree(from_node) != 0: + raise validation.InvalidSDFGError( + f"Failed to relocate the incoming edges from `{from_node}`, there are still `{state.in_edges(from_node)}`", + sdfg, + sdfg.node_id(state), + ) + assert len(from_node.in_connectors) == 0 + assert len(from_node.out_connectors) == 0 + + + def handle_intermediate_set( + self, + intermediate_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]], + state: SDFGState, + sdfg: SDFG, + map_exit_1: nodes.MapExit, + map_entry_2: nodes.MapEntry, + map_exit_2: nodes.MapExit, + is_exclusive_set: bool, + ) -> None: + """This function handles the intermediate sets. + + The function is able to handle both the shared and exclusive intermediate + output set, see `partition_first_outputs()`. The main difference is that + in exclusive mode the intermediate nodes will be fully removed from + the SDFG. While in shared mode the intermediate node will be preserved. + The function assumes that the parameter renaming was already done. + + Args: + intermediate_outputs: The set of outputs, that should be processed. + state: The state in which the map is processed. + sdfg: The SDFG that should be optimized. + map_exit_1: The exit of the first/top map. + map_entry_2: The entry of the second map. + map_exit_2: The exit of the second map. + is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. + + Notes: + Before the transformation the `state` does not have to be valid and + after this function has run the state is (most likely) invalid. + """ + + map_params = map_exit_1.map.params.copy() + + # Now we will iterate over all intermediate edges and process them. + # If not stated otherwise the comments assume that we run in exclusive mode. + for out_edge in intermediate_outputs: + # This is the intermediate node that, that we want to get rid of. + # In shared mode we want to recreate it after the second map. + inter_node: nodes.AccessNode = out_edge.dst + inter_name = inter_node.data + inter_desc = inter_node.desc(sdfg) + inter_shape = inter_desc.shape + + # Now we will determine the shape of the new intermediate. This size of + # this temporary is given by the Memlet that goes into the first map exit. + pre_exit_edges = list(state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:])) + if len(pre_exit_edges) != 1: + raise NotImplementedError() + pre_exit_edge = pre_exit_edges[0] + new_inter_shape_raw = symbolic.overapproximate(pre_exit_edge.data.subset.size()) + + # Over approximation will leave us with some unneeded size one dimensions. + # If they are removed some dace transformations (especially auto optimization) + # will have problems. + if not self.strict_dataflow: + squeezed_dims: List[int] = [] # These are the dimensions we removed. + new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. + for dim, (proposed_dim_size, full_dim_size) in enumerate(zip(new_inter_shape_raw, inter_shape)): + if full_dim_size == 1: # Must be kept! + new_inter_shape.append(proposed_dim_size) + elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. + squeezed_dims.append(dim) + else: + new_inter_shape.append(proposed_dim_size) + else: + squeezed_dims = [] + new_inter_shape = list(new_inter_shape_raw) + + # This is the name of the new "intermediate" node that we will create. + # It will only have the shape `new_inter_shape` which is basically its + # output within one Map iteration. + # NOTE: The insertion process might generate a new name. + new_inter_name: str = f"__s{sdfg.node_id(state)}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" + + # Now generate the intermediate data container. + if len(new_inter_shape) == 0: + assert pre_exit_edge.data.subset.num_elements() == 1 + is_scalar = True + new_inter_name, new_inter_desc = sdfg.add_scalar( + new_inter_name, + dtype=inter_desc.dtype, + transient=True, + storage=dtypes.StorageType.Register, + find_new_name=True, + ) + + else: + assert (pre_exit_edge.data.subset.num_elements() > 1) or all(x == 1 for x in new_inter_shape) + is_scalar = False + new_inter_name, new_inter_desc = sdfg.add_transient( + new_inter_name, + shape=new_inter_shape, + dtype=inter_desc.dtype, + find_new_name=True, + ) + new_inter_node: nodes.AccessNode = state.add_access(new_inter_name) + + # Get the subset that defined into which part of the old intermediate + # the old output edge wrote to. We need that to adjust the producer + # Memlets, since they now write into the new (smaller) intermediate. + assert pre_exit_edge.data.data == inter_name + assert pre_exit_edge.data.dst_subset is not None + producer_offset = self.compute_offset_subset( + original_subset=pre_exit_edge.data.dst_subset, + intermediate_desc=inter_desc, + map_params=map_params, + producer_offset=None, + ) + + # Memlets have a lot of additional informations, such as dynamic. + # To ensure that we get all of them, we will now copy them and modify + # the one that was originally there. We also hope that propagate will + # set the rest for us correctly. + new_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) + new_pre_exit_memlet.data = new_inter_name + new_pre_exit_memlet.dst_subset = subsets.Range.from_array(new_inter_desc) + + # New we will reroute the output Memlet, thus it will no longer pass + # through the Map exit but through the newly created intermediate. + # NOTE: We will delete the previous edge later. + new_pre_exit_edge = state.add_edge( + pre_exit_edge.src, + pre_exit_edge.src_conn, + new_inter_node, + None, + new_pre_exit_memlet, + ) + + # We now handle the MemletTree defined by this edge. + # The newly created edge, only handled the last collection step. + for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children(include_self=False): + producer_edge = producer_tree.edge + + # Associate the (already existing) Memlet with the new data. + # TODO(phimuell): Improve the code below to remove the check. + assert producer_edge.data.data == inter_name + producer_edge.data.data = new_inter_name + + if is_scalar: + producer_edge.data.dst_subset = "0" + elif producer_edge.data.dst_subset is not None: + # Since we now write into a smaller memory patch, we must + # compensate for that. We do this by substracting where the write + # originally had begun. + producer_edge.data.dst_subset.offset(producer_offset, negative=True) + producer_edge.data.dst_subset.pop(squeezed_dims) + + # Now after we have handled the input of the new intermediate node, + # we must handle its output. For this we have to "inject" the newly + # created intermediate into the second map. We do this by finding + # the input connectors on the map entry, such that we know where we + # have to reroute inside the Map. + # NOTE: Assumes that map (if connected is the direct neighbour). + conn_names: Set[str] = set() + for inter_node_out_edge in state.out_edges(inter_node): + if inter_node_out_edge.dst == map_entry_2: + assert inter_node_out_edge.dst_conn.startswith("IN_") + conn_names.add(inter_node_out_edge.dst_conn) + else: + # If we found another target than the second map entry from the + # intermediate node it means that the node _must_ survive, + # i.e. we are not in exclusive mode. + assert not is_exclusive_set + + # Now we will reroute the connections inside the second map, i.e. + # instead of consuming the old intermediate node, they will now + # consume the new intermediate node. + for in_conn_name in conn_names: + out_conn_name = "OUT_" + in_conn_name[3:] + + for inner_edge in state.out_edges_by_connector(map_entry_2, out_conn_name): + assert inner_edge.data.data == inter_name # DIRECTION!! + + # As for the producer side, we now read from a smaller array, + # So we must offset them, we use the original edge for this. + assert inner_edge.data.src_subset is not None + consumer_offset = self.compute_offset_subset( + original_subset=inner_edge.data.src_subset, + intermediate_desc=inter_desc, + map_params=map_params, + producer_offset=producer_offset, + ) + + # Now we create a new connection that instead reads from the new + # intermediate, instead of the old one. For this we use the + # old Memlet as template. However it is not fully initialized. + new_inner_memlet = copy.deepcopy(inner_edge.data) + new_inner_memlet.data = new_inter_name + + # Now we replace the edge from the SDFG. + state.remove_edge(inner_edge) + new_inner_edge = state.add_edge( + new_inter_node, + None, + inner_edge.dst, + inner_edge.dst_conn, + new_inner_memlet, + ) + + # Now modifying the Memlet, we do it after the insertion to make + # sure that the Memlet was properly initialized. + if is_scalar: + new_inner_memlet.subset = "0" + elif new_inner_memlet.src_subset is not None: + new_inner_memlet.src_subset.offset(consumer_offset, negative=True) + new_inner_memlet.src_subset.pop(squeezed_dims) + + # Now we have to make sure that all consumers are properly updated. + for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(include_self=False): + assert consumer_tree.edge.data.data == inter_name + + consumer_edge = consumer_tree.edge + consumer_edge.data.data = new_inter_name + if is_scalar: + consumer_edge.data.src_subset = "0" + elif consumer_edge.data.src_subset is not None: + consumer_edge.data.src_subset.offset(consumer_offset, negative=True) + consumer_edge.data.src_subset.pop(squeezed_dims) + + # The edge that leaves the second map entry was already deleted. We now delete + # the edges that connected the intermediate node with the second map entry. + for edge in list(state.in_edges_by_connector(map_entry_2, in_conn_name)): + assert edge.src == inter_node + state.remove_edge(edge) + map_entry_2.remove_in_connector(in_conn_name) + map_entry_2.remove_out_connector(out_conn_name) + + if is_exclusive_set: + # In exclusive mode the old intermediate node is no longer needed. + # This will also remove `out_edge` from the SDFG. + assert state.degree(inter_node) == 1 + state.remove_edge_and_connectors(out_edge) + state.remove_node(inter_node) + + state.remove_edge(pre_exit_edge) + map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) + map_exit_1.remove_out_connector(out_edge.src_conn) + del sdfg.arrays[inter_name] + + else: + # This is the shared mode, so we have to recreate the intermediate + # node, but this time it is at the exit of the second map. + state.remove_edge(pre_exit_edge) + map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) + + # This is the Memlet that goes from the map internal intermediate + # temporary node to the Map output. This will essentially restore + # or preserve the output for the intermediate node. It is important + # that we use the data that `preExitEdge` was used. + final_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) + assert pre_exit_edge.data.data == inter_name + final_pre_exit_memlet.other_subset = subsets.Range.from_array(new_inter_desc) + + new_pre_exit_conn = map_exit_2.next_connector() + state.add_edge( + new_inter_node, + None, + map_exit_2, + "IN_" + new_pre_exit_conn, + final_pre_exit_memlet, + ) + state.add_edge( + map_exit_2, + "OUT_" + new_pre_exit_conn, + inter_node, + out_edge.dst_conn, + copy.deepcopy(out_edge.data), + ) + map_exit_2.add_in_connector("IN_" + new_pre_exit_conn) + map_exit_2.add_out_connector("OUT_" + new_pre_exit_conn) + + map_exit_1.remove_out_connector(out_edge.src_conn) + state.remove_edge(out_edge) + + + def compute_offset_subset( + self, + original_subset: subsets.Range, + intermediate_desc: data.Data, + map_params: List[str], + producer_offset: Union[subsets.Range, None], + ) -> subsets.Range: + """Computes the memlet to correct read and writes of the intermediate. + + This is the value that must be substracted from the memlets to adjust, i.e + (`memlet_to_adjust(correction, negative=True)`). If `producer_offset` is + `None` then the function computes the correction that should be applied to + the producer memlets, i.e. the memlets of the tree converging at + `intermediate_node`. If `producer_offset` is given, it should be the output + of the previous call to this function, with `producer_offset=None`. In this + case the function computes the correction for the consumer side, i.e. the + memlet tree that originates at `intermediate_desc`. + + Args: + original_subset: The original subset that was used to write into the + intermediate, must be renamed to the final map parameter. + intermediate_desc: The original intermediate data descriptor. + map_params: The parameter of the final map. + producer_offset: The correction that was applied to the producer side. + """ + assert not isinstance(intermediate_desc, data.View) + final_offset: subsets.Range = None + if isinstance(intermediate_desc, data.Scalar): + # If the intermediate was a scalar, then it will remain a scalar. + # Thus there is no correction that we must apply. + return subsets.Range.from_string("0") + + elif isinstance(intermediate_desc, data.Array): + basic_offsets = original_subset.min_element() + offset_list = [] + for d in range(original_subset.dims()): + d_range = subsets.Range([original_subset[d]]) + if d_range.free_symbols.intersection(map_params): + offset_list.append(d_range[0]) + else: + offset_list.append((basic_offsets[d], basic_offsets[d], 1)) + final_offset = subsets.Range(offset_list) + + else: + raise TypeError(f"Does not know how to compute the subset offset for '{type(intermediate_desc).__name__}'.") + + if producer_offset is not None: + # Here we are correcting some parts that over approximate (which partially + # does under approximate) might screw up. Consider two maps, the first + # map only writes the subset `[:, 2:6]`, thus the new intermediate will + # have shape `(1, 4)`. Now also imagine that the second map only reads + # the elements `[:, 3]`. From this we see that we can only correct the + # consumer side if we also take the producer side into consideration! + # See also the `transformations/mapfusion_test.py::test_offset_correction_*` + # tests for more. + final_offset.offset( + final_offset.offset_new( + producer_offset, + negative=True, + ), + negative=True, + ) + return final_offset + + + def can_topologically_be_fused( + self, + map_entry_1: nodes.MapEntry, + map_entry_2: nodes.MapEntry, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Performs basic checks if the maps can be fused. + + This function only checks constrains that are common between serial and + parallel map fusion process, which includes: + - The scope of the maps. + - The scheduling of the maps. + - The map parameters. + + Args: + map_entry_1: The entry of the first (in serial case the top) map. + map_exit_2: The entry of the second (in serial case the bottom) map. + graph: The SDFGState in which the maps are located. + sdfg: The SDFG itself. + permissive: Currently unused. + """ + if self.only_inner_maps and self.only_toplevel_maps: + raise ValueError("You specified both `only_inner_maps` and `only_toplevel_maps`.") + + # Ensure that both have the same schedule + if map_entry_1.map.schedule != map_entry_2.map.schedule: + return False + + # Fusing is only possible if the two entries are in the same scope. + scope = graph.scope_dict() + if scope[map_entry_1] != scope[map_entry_2]: + return False + elif self.only_inner_maps: + if scope[map_entry_1] is None: + return False + elif self.only_toplevel_maps: + if scope[map_entry_1] is not None: + return False + + # We will now check if there exists a remapping that of the map parameter + if self.find_parameter_remapping(first_map=map_entry_1.map, second_map=map_entry_2.map) is None: + return False + + return True + + + def has_read_write_dependency( + self, + map_entry_1: nodes.MapEntry, + map_entry_2: nodes.MapEntry, + state: SDFGState, + sdfg: SDFG, + ) -> bool: + """Test if there is a read write dependency between the two maps to be fused. + + The function checks two different things. + - The function will make sure that there is no read write dependency between + the input and output of the fused maps. For that it will inspect the + respective subsets. + - The second part partially checks the intermediate nodes, it mostly ensures + that there are not views and that they are not used as inputs or outputs + at the same time. However, the function will not check for read write + conflicts in this set, this is done in the partition function. + + Returns: + `True` if there is a conflict between the maps that can not be handled. + If there is no conflict or if the conflict can be handled `False` + is returned. + + Args: + map_entry_1: The entry node of the first map. + map_entry_2: The entry node of the second map. + state: The state on which we operate. + sdfg: The SDFG on which we operate. + """ + map_exit_1: nodes.MapExit = state.exit_node(map_entry_1) + map_exit_2: nodes.MapExit = state.exit_node(map_entry_2) + + # Get the read and write sets of the different maps, note that Views + # are not resolved yet. + access_sets: List[Dict[str, nodes.AccessNode]] = [] + for scope_node in [map_entry_1, map_exit_1, map_entry_2, map_exit_2]: + access_set: Set[nodes.AccessNode] = self.get_access_set(scope_node, state) + access_sets.append({node.data: node for node in access_set}) + # If two different access nodes of the same scoping node refers to the + # same data, then we consider this as a dependency we can not handle. + # It is only a problem for the intermediate nodes and might be possible + # to handle, but doing so is hard, so we just forbid it. + if len(access_set) != len(access_sets[-1]): + return True + read_map_1, write_map_1, read_map_2, write_map_2 = access_sets + + # It might be possible that there are views, so we have to resolve them. + # We also already get the name of the data container. + # Note that `len(real_read_map_1) <= len(read_map_1)` holds because of Views. + resolved_sets: List[Set[str]] = [] + for unresolved_set in [read_map_1, write_map_1, read_map_2, write_map_2]: + resolved_sets.append({ + self.track_view(node, state, sdfg).data if self.is_view(node, sdfg) else node.data + for node in unresolved_set.values() + }) + # If the resolved and unresolved names do not have the same length. + # Then different views point to the same location, which we forbid + if len(unresolved_set) != len(resolved_sets[-1]): + return False + real_read_map_1, real_write_map_1, real_read_map_2, real_write_map_2 = resolved_sets + + # We do not allow that the first and second map each write to the same data. + if not real_write_map_1.isdisjoint(real_write_map_2): + return True + + # If there is no overlap in what is (totally) read and written, there will be no conflict. + # This must come before the check of disjoint write. + if (real_read_map_1 | real_read_map_2).isdisjoint(real_write_map_1 | real_write_map_2): + return False + + # These are the names (unresolved) and the access nodes of the data that is used + # to transmit information between the maps. The partition function ensures that + # these nodes are directly connected to the two maps. + exchange_names: Set[str] = set(write_map_1.keys()).intersection(read_map_2.keys()) + exchange_nodes: Set[nodes.AccessNode] = set(write_map_1.values()).intersection(read_map_2.values()) + + # If the number are different then a data is accessed through multiple nodes. + if len(exchange_names) != len(exchange_nodes): + return True + assert all(exchange_node.data in exchange_names for exchange_node in exchange_nodes) + + # For simplicity we assume that the nodes used for exchange are not views. + if any(self.is_view(exchange_node, sdfg) for exchange_node in exchange_nodes): + return True + + # This is the names of the node that are used as input of the first map and + # as output of the second map. We have to ensure that there is no data + # dependency between these nodes. + fused_inout_data_names: Set[str] = set(read_map_1.keys()).intersection(write_map_2.keys()) + + # If a data container is used as input and output then it can not be a view (simplicity) + if any(self.is_view(read_map_1[name], sdfg) for name in fused_inout_data_names): + return True + + # A data container can be used as input and output. But we do not allow that + # it is also used as intermediate or exchange data. This is an important check. + if not fused_inout_data_names.isdisjoint(exchange_names): + return True + + # Get the replacement dict for changing the map variables from the subsets of + # the second map. + repl_dict = self.find_parameter_remapping(map_entry_1.map, map_exit_2.map) + + # Now we inspect if there is a read write dependency, between data that is + # used as input and output of the fused map. There is no problem is they + # are pointwise, i.e. in each iteration the same locations are accessed. + # Essentially they all boil down to `a += 1`. + for inout_data_name in fused_inout_data_names: + all_subsets: List[subsets.Subset] = [] + # The subsets that define reading are given by the first map's entry node + all_subsets.extend( + self.find_subsets( + node=read_map_1[inout_data_name], + scope_node=map_entry_1, + state=state, + sdfg=sdfg, + repl_dict=None, + )) + # While the subsets defining writing are given by the second map's exit + # node, there we also have to apply renaming. + all_subsets.extend( + self.find_subsets( + node=write_map_2[inout_data_name], + scope_node=map_exit_2, + state=state, + sdfg=sdfg, + repl_dict=repl_dict, + )) + # Now we can test if these subsets are point wise + if not self.test_if_subsets_are_point_wise(all_subsets): + return True + + # No read write dependency was found. + return False + + + def test_if_subsets_are_point_wise(self, subsets_to_check: List[subsets.Subset]) -> bool: + """Point wise means that they are all the same. + + If a series of subsets are point wise it means that all Memlets, access + the same data. This is an important property because the whole map fusion + is build upon this. + If the subsets originates from different maps, then they must have been + renamed. + + Args: + subsets_to_check: The list of subsets that should be checked. + """ + assert len(subsets_to_check) > 1 + + # We will check everything against the master subset. + master_subset = subsets_to_check[0] + for ssidx in range(1, len(subsets_to_check)): + subset = subsets_to_check[ssidx] + if isinstance(subset, subsets.Indices): + subset = subsets.Range.from_indices(subset) + # Do we also need the reverse? See below why. + if any(r != (0, 0, 1) for r in subset.offset_new(master_subset, negative=True)): + return False + else: + # The original code used `Range.offset` here, but that one had trouble + # for `r1 = 'j, 0:10'` and `r2 = 'j, 0`. The solution would be to test + # symmetrically, i.e. `r1 - r2` and `r2 - r1`. However, if we would + # have `r2_1 = 'j, 0:10'` it consider it as failing, which is not + # what we want. Thus we will use symmetric cover. + if not master_subset.covers(subset): + return False + if not subset.covers(master_subset): + return False + + # All subsets are equal to the master subset, thus they are equal to each other. + # This means that the data accesses, described by this transformation is + # point wise + return True + + + def is_shared_data( + self, + data: nodes.AccessNode, + sdfg: dace.SDFG, + ) -> bool: + """Tests if `data` is interstate data, an can not be removed. + + Interstate data is used to transmit data between multiple state or by + extension within the state. Thus it must be classified as a shared output. + This function will go through the SDFG to and collect the names of all data + container that should be classified as shared. Note that this is an over + approximation as it does not take the location into account, i.e. "is no longer + used". + + Args: + transient: The transient that should be checked. + sdfg: The SDFG containing the array. + + Note: + The function computes the this set once for every SDFG and then caches it. + There is no mechanism to detect if the cache must be evicted. However, + as long as no additional data is added, there is no problem. + """ + if sdfg not in self._shared_data: + self._compute_shared_data_in(sdfg) + return data.data in self._shared_data[sdfg] + + + def _compute_shared_data_in( + self, + sdfg: dace.SDFG, + ) -> None: + """Updates the internal set of shared data/interstate data of `self` for `sdfg`. + + See the documentation for `self.is_shared_data()` for a description. + + Args: + sdfg: The SDFG for which the set of shared data should be computed. + """ + self._shared_data[sdfg] = set(sdfg.shared_transients()) + + + def _compute_multi_write_data( + self, + state: SDFGState, + sdfg: SDFG, + ) -> Set[str]: + """Computes data inside a _single_ state, that is written multiple times. + + Essentially this function computes the set of data that does not follow + the single static assignment idiom. The function also resolves views. + If an access node, refers to a view, not only the view itself, but also + the data it refers to is added to the set. + + Args: + state: The state that should be examined. + sdfg: The SDFG object. + + Note: + This information is used by the partition function (in case strict data + flow mode is enabled), in strict data flow mode only. The current + implementation is rather simple as it only checks if a data is written + to multiple times in the same state. + """ + data_written_to: Set[str] = set() + multi_write_data: Set[str] = set() + + for access_node in state.data_nodes(): + if state.in_degree(access_node) == 0: + continue + if access_node.data in data_written_to: + multi_write_data.add(access_node.data) + elif self.is_view(access_node, sdfg): + # This is an over approximation. + multi_write_data.update([access_node.data, self.track_view(access_node, state, sdfg).data]) + data_written_to.add(access_node.data) + return multi_write_data + + + def find_parameter_remapping(self, first_map: nodes.Map, second_map: nodes.Map) -> Union[Dict[str, str], None]: + """Computes the parameter remapping for the parameters of the _second_ map. + + The returned `dict` maps the parameters of the second map (keys) to parameter + names of the first map (values). Because of how the replace function works + the `dict` describes how to replace the parameters of the second map + with parameters of the first map. + Parameters that already have the correct name and compatible range, are not + included in the return value, thus the keys and values are always different. + If no renaming at is _needed_, i.e. all parameter have the same name and range, + then the function returns an empty `dict`. + If no remapping exists, then the function will return `None`. + + Args: + first_map: The first map (these parameters will be replaced). + second_map: The second map, these parameters acts as source. + """ + + # The parameter names + first_params: List[str] = first_map.params + second_params: List[str] = second_map.params + + if len(first_params) != len(second_params): + return None + + # The ranges, however, we apply some post processing to them. + simp = lambda e: symbolic.simplify_ext(symbolic.simplify(e)) + first_rngs: Dict[str, Tuple[Any, Any, Any]] = { + param: tuple(simp(r) for r in rng) + for param, rng in zip(first_params, first_map.range) + } + second_rngs: Dict[str, Tuple[Any, Any, Any]] = { + param: tuple(simp(r) for r in rng) + for param, rng in zip(second_params, second_map.range) + } + + # Parameters of the second map that have not yet been matched to a parameter + # of the first map and vice versa. + unmapped_second_params: Set[str] = set(second_params) + unused_first_params: Set[str] = set(first_params) + + # This is the result (`second_param -> first_param`), note that if no renaming + # is needed then the parameter is not present in the mapping. + final_mapping: Dict[str, str] = {} + + # First we identify the parameters that already have the correct name. + for param in set(first_params).intersection(second_params): + first_rng = first_rngs[param] + second_rng = second_rngs[param] + + if first_rng == second_rng: + # They have the same name and the same range, this is already a match. + # Because the names are already the same, we do not have to enter them + # in the `final_mapping` + unmapped_second_params.discard(param) + unused_first_params.discard(param) + + # Check if no remapping is needed. + if len(unmapped_second_params) == 0: + return {} + + # Now we go through all the parameters that we have not mapped yet. + # All of them will result in a remapping. + for unmapped_second_param in unmapped_second_params: + second_rng = second_rngs[unmapped_second_param] + assert unmapped_second_param not in final_mapping + + # Now look in all not yet used parameters of the first map which to use. + for candidate_param in unused_first_params: + candidate_rng = first_rngs[candidate_param] + if candidate_rng == second_rng: + final_mapping[unmapped_second_param] = candidate_param + unused_first_params.discard(candidate_param) + break + else: + # We did not find a candidate, so the remapping does not exist + return None + + assert len(unused_first_params) == 0 + assert len(final_mapping) == len(unmapped_second_params) + return final_mapping + + + def rename_map_parameters( + self, + first_map: nodes.Map, + second_map: nodes.Map, + second_map_entry: nodes.MapEntry, + state: SDFGState, + ) -> None: + """Replaces the map parameters of the second map with names from the first. + + The replacement is done in a safe way, thus `{'i': 'j', 'j': 'i'}` is + handled correct. The function assumes that a proper replacement exists. + The replacement is computed by calling `self.find_parameter_remapping()`. + + Args: + first_map: The first map (these are the final parameter). + second_map: The second map, this map will be replaced. + second_map_entry: The entry node of the second map. + state: The SDFGState on which we operate. + """ + # Compute the replacement dict. + repl_dict: Dict[str, str] = self.find_parameter_remapping(first_map=first_map, second_map=second_map) + + if repl_dict is None: + raise RuntimeError("The replacement does not exist") + if len(repl_dict) == 0: + return + + second_map_scope = state.scope_subgraph(entry_node=second_map_entry) + # Why is this thing is symbolic and not in replace? + symbolic.safe_replace( + mapping=repl_dict, + replace_callback=second_map_scope.replace_dict, + ) + + # For some odd reason the replace function does not modify the range and + # parameter of the map, so we will do it the hard way. + second_map.params = copy.deepcopy(first_map.params) + second_map.range = copy.deepcopy(first_map.range) + + + def is_node_reachable_from( + self, + graph: Union[dace.SDFG, dace.SDFGState], + begin: nodes.Node, + end: nodes.Node, + ) -> bool: + """Test if the node `end` can be reached from `begin`. + + Essentially the function starts a DFS at `begin`. If an edge is found that lead + to `end` the function returns `True`. If the node is never found `False` is + returned. + + Args: + graph: The graph to operate on. + begin: The start of the DFS. + end: The node that should be located. + """ + + def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: + return (edge.dst for edge in graph.out_edges(node)) + + to_visit: List[nodes.Node] = [begin] + seen: Set[nodes.Node] = set() + + while len(to_visit) > 0: + node: nodes.Node = to_visit.pop() + if node == end: + return True + elif node not in seen: + to_visit.extend(next_nodes(node)) + seen.add(node) + + # We never found `end` + return False + + + def get_access_set( + self, + scope_node: Union[nodes.MapEntry, nodes.MapExit], + state: SDFGState, + ) -> Set[nodes.AccessNode]: + """Computes the access set of a "scope node". + + If `scope_node` is a `MapEntry` it will operate on the set of incoming edges + and if it is an `MapExit` on the set of outgoing edges. The function will + then determine all access nodes that have a connection through these edges + to the scope nodes (edges that does not lead to access nodes are ignored). + The function returns a set that contains all access nodes that were found. + It is important that this set will also contain views. + + Args: + scope_node: The scope node that should be evaluated. + state: The state in which we operate. + """ + if isinstance(scope_node, nodes.MapEntry): + get_edges = lambda node: state.in_edges(node) + other_node = lambda e: e.src + else: + get_edges = lambda node: state.out_edges(node) + other_node = lambda e: e.dst + access_set: Set[nodes.AccessNode] = { + node + for node in map(other_node, get_edges(scope_node)) if isinstance(node, nodes.AccessNode) + } + + return access_set + + + def find_subsets( + self, + node: nodes.AccessNode, + scope_node: Union[nodes.MapExit, nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, + repl_dict: Optional[Dict[str, str]], + ) -> List[subsets.Subset]: + """Finds all subsets that access `node` within `scope_node`. + + The function will not start a search for all consumer/producers. + Instead it will locate the edges which is immediately inside the + map scope. + + Args: + node: The access node that should be examined. + scope_node: We are only interested in data that flows through this node. + state: The state in which we operate. + sdfg: The SDFG object. + """ + + # Is the node used for reading or for writing. + # This influences how we have to proceed. + if isinstance(scope_node, nodes.MapEntry): + outer_edges_to_inspect = [e for e in state.in_edges(scope_node) if e.src == node] + get_subset = lambda e: e.data.src_subset + get_inner_edges = lambda e: state.out_edges_by_connector(scope_node, "OUT_" + e.dst_conn[3:]) + else: + outer_edges_to_inspect = [e for e in state.out_edges(scope_node) if e.dst == node] + get_subset = lambda e: e.data.dst_subset + get_inner_edges = lambda e: state.in_edges_by_connector(scope_node, "IN_" + e.src_conn[4:]) + + found_subsets: List[subsets.Subset] = [] + for edge in outer_edges_to_inspect: + found_subsets.extend(get_subset(e) for e in get_inner_edges(edge)) + assert len(found_subsets) > 0, "Could not find any subsets." + assert not any(subset is None for subset in found_subsets) + + found_subsets = copy.deepcopy(found_subsets) + if repl_dict: + for subset in found_subsets: + # Replace happens in place + symbolic.safe_replace(repl_dict, subset.replace) + + return found_subsets + + + def is_view( + self, + node: nodes.AccessNode, + sdfg: SDFG, + ) -> bool: + """Tests if `node` points to a view or not.""" + node_desc: data.Data = node.desc(sdfg) + return isinstance(node_desc, data.View) + + + def track_view( + self, + view: nodes.AccessNode, + state: SDFGState, + sdfg: SDFG, + ) -> nodes.AccessNode: + """Find the original data of a View. + + Given the View `view`, the function will trace the view back to the original + access node. For convenience, if `view` is not a `View` the argument will be + returned. + + Args: + view: The view that should be traced. + state: The state in which we operate. + sdfg: The SDFG on which we operate. + """ + + # Test if it is a view at all, if not return the passed node as source. + if not self.is_view(view, sdfg): + return view + + # First determine if the view is used for reading or writing. + curr_edge = dace.sdfg.utils.get_view_edge(state, view) + if curr_edge is None: + raise RuntimeError(f"Failed to determine the direction of the view '{view}'.") + if curr_edge.dst_conn == "views": + # The view is used for reading. + next_node = lambda curr_edge: curr_edge.src + elif curr_edge.src_conn == "views": + # The view is used for writing. + next_node = lambda curr_edge: curr_edge.dst + else: + raise RuntimeError(f"Failed to determine the direction of the view '{view}' | {curr_edge}.") + + # Now trace the view back. + org_view = view + view = next_node(curr_edge) + while self.is_view(view, sdfg): + curr_edge = dace.sdfg.utils.get_view_edge(state, view) + if curr_edge is None: + raise RuntimeError(f"View tracing of '{org_view}' failed at note '{view}'.") + view = next_node(curr_edge) + return view diff --git a/dace/transformation/dataflow/map_fusion_helper.py b/dace/transformation/dataflow/map_fusion_helper.py deleted file mode 100644 index deadeee5b4..0000000000 --- a/dace/transformation/dataflow/map_fusion_helper.py +++ /dev/null @@ -1,632 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -"""Implements Helper functionaliyies for map fusion""" - -import copy -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union - -import dace -from dace import data, properties, subsets, symbolic, transformation -from dace.sdfg import SDFG, SDFGState, nodes, validation -from dace.transformation import helpers - - -@properties.make_properties -class MapFusionHelper(transformation.SingleStateTransformation): - """Common parts of the parallel and serial map fusion transformation. - - Args: - only_inner_maps: Only match Maps that are internal, i.e. inside another Map. - only_toplevel_maps: Only consider Maps that are at the top. - strict_dataflow: If `True`, the transformation ensures a more - stricter version of the data flow. - - Note: - If `strict_dataflow` mode is enabled then the transformation will not remove - _direct_ data flow dependency from the graph. Furthermore, the transformation - will not remove size 1 dimensions of intermediate it crates. - This is a compatibility mode, that will limit the applicability of the - transformation, but might help transformations that does not fully analyse - the graph. - """ - - only_toplevel_maps = properties.Property( - dtype=bool, - default=False, - desc="Only perform fusing if the Maps are in the top level.", - ) - only_inner_maps = properties.Property( - dtype=bool, - default=False, - desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", - ) - strict_dataflow = properties.Property( - dtype=bool, - default=False, - desc="If `True` then the transformation will ensure a more stricter data flow.", - ) - # Maps SDFGs to the set of data that can not be removed, - # because they transmit data _between states_, such data will be made 'shared'. - # This variable acts as a cache, and is managed by 'is_shared_data()'. - _shared_data: Dict[SDFG, Set[str]] - - def __init__( - self, - only_inner_maps: Optional[bool] = None, - only_toplevel_maps: Optional[bool] = None, - strict_dataflow: Optional[bool] = None, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - if only_toplevel_maps is not None: - self.only_toplevel_maps = bool(only_toplevel_maps) - if only_inner_maps is not None: - self.only_inner_maps = bool(only_inner_maps) - if strict_dataflow is not None: - self.strict_dataflow = bool(strict_dataflow) - self._shared_data = {} - - @classmethod - def expressions(cls) -> bool: - raise RuntimeError("The `MapFusionHelper` is not a transformation on its own.") - - def can_be_fused( - self, - map_entry_1: nodes.MapEntry, - map_entry_2: nodes.MapEntry, - graph: Union[dace.SDFGState, dace.SDFG], - sdfg: dace.SDFG, - permissive: bool = False, - ) -> bool: - """Performs basic checks if the maps can be fused. - - This function only checks constrains that are common between serial and - parallel map fusion process, which includes: - - The scope of the maps. - - The scheduling of the maps. - - The map parameters. - - Args: - map_entry_1: The entry of the first (in serial case the top) map. - map_exit_2: The entry of the second (in serial case the bottom) map. - graph: The SDFGState in which the maps are located. - sdfg: The SDFG itself. - permissive: Currently unused. - """ - if self.only_inner_maps and self.only_toplevel_maps: - raise ValueError("You specified both `only_inner_maps` and `only_toplevel_maps`.") - - # Ensure that both have the same schedule - if map_entry_1.map.schedule != map_entry_2.map.schedule: - return False - - # Fusing is only possible if the two entries are in the same scope. - scope = graph.scope_dict() - if scope[map_entry_1] != scope[map_entry_2]: - return False - elif self.only_inner_maps: - if scope[map_entry_1] is None: - return False - elif self.only_toplevel_maps: - if scope[map_entry_1] is not None: - return False - - # We will now check if there exists a remapping that of the map parameter - if self.find_parameter_remapping(first_map=map_entry_1.map, second_map=map_entry_2.map) is None: - return False - - return True - - def relocate_nodes( - self, - from_node: Union[nodes.MapExit, nodes.MapEntry], - to_node: Union[nodes.MapExit, nodes.MapEntry], - state: SDFGState, - sdfg: SDFG, - ) -> None: - """Move the connectors and edges from `from_node` to `to_nodes` node. - - This function will only rewire the edges, it does not remove the nodes - themselves. Furthermore, this function should be called twice per Map, - once for the entry and then for the exit. - While it does not remove the node themselves if guarantees that the - `from_node` has degree zero. - The function assumes that the parameter renaming was already done. - - Args: - from_node: Node from which the edges should be removed. - to_node: Node to which the edges should reconnect. - state: The state in which the operation happens. - sdfg: The SDFG that is modified. - """ - - # Now we relocate empty Memlets, from the `from_node` to the `to_node` - for empty_edge in list(filter(lambda e: e.data.is_empty(), state.out_edges(from_node))): - helpers.redirect_edge(state, empty_edge, new_src=to_node) - for empty_edge in list(filter(lambda e: e.data.is_empty(), state.in_edges(from_node))): - helpers.redirect_edge(state, empty_edge, new_dst=to_node) - - # We now ensure that there is only one empty Memlet from the `to_node` to any other node. - # Although it is allowed, we try to prevent it. - empty_targets: Set[nodes.Node] = set() - for empty_edge in list(filter(lambda e: e.data.is_empty(), state.all_edges(to_node))): - if empty_edge.dst in empty_targets: - state.remove_edge(empty_edge) - empty_targets.add(empty_edge.dst) - - # We now determine which edges we have to migrate, for this we are looking at - # the incoming edges, because this allows us also to detect dynamic map ranges. - # TODO(phimuell): If there is already a connection to the node, reuse this. - for edge_to_move in list(state.in_edges(from_node)): - assert isinstance(edge_to_move.dst_conn, str) - - if not edge_to_move.dst_conn.startswith("IN_"): - # Dynamic Map Range - # The connector name simply defines a variable name that is used, - # inside the Map scope to define a variable. We handle it directly. - dmr_symbol = edge_to_move.dst_conn - - # TODO(phimuell): Check if the symbol is really unused in the target scope. - if dmr_symbol in to_node.in_connectors: - raise NotImplementedError(f"Tried to move the dynamic map range '{dmr_symbol}' from {from_node}'" - f" to '{to_node}', but the symbol is already known there, but the" - " renaming is not implemented.") - if not to_node.add_in_connector(dmr_symbol, force=False): - raise RuntimeError( # Might fail because of out connectors. - f"Failed to add the dynamic map range symbol '{dmr_symbol}' to '{to_node}'.") - helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) - from_node.remove_in_connector(dmr_symbol) - - else: - # We have a Passthrough connection, i.e. there exists a matching `OUT_`. - old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix - new_conn = to_node.next_connector(old_conn) - - to_node.add_in_connector("IN_" + new_conn) - for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): - helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) - to_node.add_out_connector("OUT_" + new_conn) - for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): - helpers.redirect_edge(state, e, new_src=to_node, new_src_conn="OUT_" + new_conn) - from_node.remove_in_connector("IN_" + old_conn) - from_node.remove_out_connector("OUT_" + old_conn) - - # Check if we succeeded. - if state.out_degree(from_node) != 0: - raise validation.InvalidSDFGError( - f"Failed to relocate the outgoing edges from `{from_node}`, there are still `{state.out_edges(from_node)}`", - sdfg, - sdfg.node_id(state), - ) - if state.in_degree(from_node) != 0: - raise validation.InvalidSDFGError( - f"Failed to relocate the incoming edges from `{from_node}`, there are still `{state.in_edges(from_node)}`", - sdfg, - sdfg.node_id(state), - ) - assert len(from_node.in_connectors) == 0 - assert len(from_node.out_connectors) == 0 - - def find_parameter_remapping(self, first_map: nodes.Map, second_map: nodes.Map) -> Union[Dict[str, str], None]: - """Computes the parameter remapping for the parameters of the _second_ map. - - The returned `dict` maps the parameters of the second map (keys) to parameter - names of the first map (values). Because of how the replace function works - the `dict` describes how to replace the parameters of the second map - with parameters of the first map. - Parameters that already have the correct name and compatible range, are not - included in the return value, thus the keys and values are always different. - If no renaming at is _needed_, i.e. all parameter have the same name and range, - then the function returns an empty `dict`. - If no remapping exists, then the function will return `None`. - - Args: - first_map: The first map (these parameters will be replaced). - second_map: The second map, these parameters acts as source. - """ - - # The parameter names - first_params: List[str] = first_map.params - second_params: List[str] = second_map.params - - if len(first_params) != len(second_params): - return None - - # The ranges, however, we apply some post processing to them. - simp = lambda e: symbolic.simplify_ext(symbolic.simplify(e)) - first_rngs: Dict[str, Tuple[Any, Any, Any]] = { - param: tuple(simp(r) for r in rng) - for param, rng in zip(first_params, first_map.range) - } - second_rngs: Dict[str, Tuple[Any, Any, Any]] = { - param: tuple(simp(r) for r in rng) - for param, rng in zip(second_params, second_map.range) - } - - # Parameters of the second map that have not yet been matched to a parameter - # of the first map and vice versa. - unmapped_second_params: Set[str] = set(second_params) - unused_first_params: Set[str] = set(first_params) - - # This is the result (`second_param -> first_param`), note that if no renaming - # is needed then the parameter is not present in the mapping. - final_mapping: Dict[str, str] = {} - - # First we identify the parameters that already have the correct name. - for param in set(first_params).intersection(second_params): - first_rng = first_rngs[param] - second_rng = second_rngs[param] - - if first_rng == second_rng: - # They have the same name and the same range, this is already a match. - # Because the names are already the same, we do not have to enter them - # in the `final_mapping` - unmapped_second_params.discard(param) - unused_first_params.discard(param) - - # Check if no remapping is needed. - if len(unmapped_second_params) == 0: - return {} - - # Now we go through all the parameters that we have not mapped yet. - # All of them will result in a remapping. - for unmapped_second_param in unmapped_second_params: - second_rng = second_rngs[unmapped_second_param] - assert unmapped_second_param not in final_mapping - - # Now look in all not yet used parameters of the first map which to use. - for candidate_param in unused_first_params: - candidate_rng = first_rngs[candidate_param] - if candidate_rng == second_rng: - final_mapping[unmapped_second_param] = candidate_param - unused_first_params.discard(candidate_param) - break - else: - # We did not find a candidate, so the remapping does not exist - return None - - assert len(unused_first_params) == 0 - assert len(final_mapping) == len(unmapped_second_params) - return final_mapping - - def rename_map_parameters( - self, - first_map: nodes.Map, - second_map: nodes.Map, - second_map_entry: nodes.MapEntry, - state: SDFGState, - ) -> None: - """Replaces the map parameters of the second map with names from the first. - - The replacement is done in a safe way, thus `{'i': 'j', 'j': 'i'}` is - handled correct. The function assumes that a proper replacement exists. - The replacement is computed by calling `self.find_parameter_remapping()`. - - Args: - first_map: The first map (these are the final parameter). - second_map: The second map, this map will be replaced. - second_map_entry: The entry node of the second map. - state: The SDFGState on which we operate. - """ - # Compute the replacement dict. - repl_dict: Dict[str, str] = self.find_parameter_remapping(first_map=first_map, second_map=second_map) - - if repl_dict is None: - raise RuntimeError("The replacement does not exist") - if len(repl_dict) == 0: - return - - second_map_scope = state.scope_subgraph(entry_node=second_map_entry) - # Why is this thing is symbolic and not in replace? - symbolic.safe_replace( - mapping=repl_dict, - replace_callback=second_map_scope.replace_dict, - ) - - # For some odd reason the replace function does not modify the range and - # parameter of the map, so we will do it the hard way. - second_map.params = copy.deepcopy(first_map.params) - second_map.range = copy.deepcopy(first_map.range) - - def is_shared_data( - self, - data: nodes.AccessNode, - sdfg: dace.SDFG, - ) -> bool: - """Tests if `data` is interstate data, an can not be removed. - - Interstate data is used to transmit data between multiple state or by - extension within the state. Thus it must be classified as a shared output. - This function will go through the SDFG to and collect the names of all data - container that should be classified as shared. Note that this is an over - approximation as it does not take the location into account, i.e. "is no longer - used". - - Args: - transient: The transient that should be checked. - sdfg: The SDFG containing the array. - - Note: - The function computes the this set once for every SDFG and then caches it. - There is no mechanism to detect if the cache must be evicted. However, - as long as no additional data is added, there is no problem. - """ - if sdfg not in self._shared_data: - self._compute_shared_data(sdfg) - return data.data in self._shared_data[sdfg] - - def _compute_shared_data( - self, - sdfg: dace.SDFG, - ) -> None: - """Updates the internal set of shared data/interstate data of `self` for `sdfg`. - - See the documentation for `self.is_shared_data()` for a description. - - Args: - sdfg: The SDFG for which the set of shared data should be computed. - """ - # Shared data of this SDFG. - shared_data: Set[str] = set() - - # All global data can not be removed, so it must always be shared. - for data_name, data_desc in sdfg.arrays.items(): - if not data_desc.transient: - shared_data.add(data_name) - elif isinstance(data_desc, dace.data.Scalar): - shared_data.add(data_name) - - # We go through all states and classify the nodes/data: - # - Data is referred to in different states. - # - The access node is a view (both have to survive). - # - Transient sink or source node. - # - The access node has output degree larger than 1 (input degrees larger - # than one, will always be partitioned as shared anyway). - prevously_seen_data: Set[str] = set() - interstate_read_symbols: Set[str] = set() - for state in sdfg.nodes(): - for access_node in state.data_nodes(): - - if access_node.data in shared_data: - # The data was already classified to be shared data - pass - - elif access_node.data in prevously_seen_data: - # We have seen this data before, either in this state or in - # a previous one, but we did not classifies it as shared back then - shared_data.add(access_node.data) - - if state.in_degree(access_node) == 0: - # (Transient) sink nodes are used in other states, or simplify - # will get rid of them. - shared_data.add(access_node.data) - - elif state.out_degree(access_node) != 1: # state.out_degree() == 0 or state.out_degree() > 1 - # The access node is either a source node (it is shared in another - # state) or the node has a degree larger than one, so it is used - # in this state somewhere else. - shared_data.add(access_node.data) - - elif self.is_view(node=access_node, sdfg=sdfg): - # To ensure that the write to the view happens, both have to be shared. - viewed_data: str = self.track_view(view=access_node, state=state, sdfg=sdfg).data - shared_data.update([access_node.data, viewed_data]) - prevously_seen_data.update([access_node.data, viewed_data]) - - else: - # The node was not classified as shared data, so we record that - # we saw it. Note that a node that was immediately classified - # as shared node will never be added to this set, but a data - # that was found twice will be inside this list. - prevously_seen_data.add(access_node.data) - - # Now we are collecting all symbols that interstate edges read from. - for edge in sdfg.edges(): - interstate_read_symbols.update(edge.data.read_symbols()) - - # We also have to keep everything the edges referrers to and is an array. - shared_data.update(interstate_read_symbols.intersection(prevously_seen_data)) - - # Update the internal cache - self._shared_data[sdfg] = shared_data - - def _compute_multi_write_data( - self, - state: SDFGState, - sdfg: SDFG, - ) -> Set[str]: - """Computes data inside a _single_ state, that is written multiple times. - - Essentially this function computes the set of data that does not follow - the single static assignment idiom. The function also resolves views. - If an access node, refers to a view, not only the view itself, but also - the data it refers to is added to the set. - - Args: - state: The state that should be examined. - sdfg: The SDFG object. - - Note: - This information is used by the partition function (in case strict data - flow mode is enabled), in strict data flow mode only. The current - implementation is rather simple as it only checks if a data is written - to multiple times in the same state. - """ - data_written_to: Set[str] = set() - multi_write_data: Set[str] = set() - - for access_node in state.data_nodes(): - if state.in_degree(access_node) == 0: - continue - if access_node.data in data_written_to: - multi_write_data.add(access_node.data) - elif self.is_view(access_node, sdfg): - # This is an over approximation. - multi_write_data.update([access_node.data, self.track_view(access_node, state, sdfg).data]) - data_written_to.add(access_node.data) - return multi_write_data - - def is_node_reachable_from( - self, - graph: Union[dace.SDFG, dace.SDFGState], - begin: nodes.Node, - end: nodes.Node, - ) -> bool: - """Test if the node `end` can be reached from `begin`. - - Essentially the function starts a DFS at `begin`. If an edge is found that lead - to `end` the function returns `True`. If the node is never found `False` is - returned. - - Args: - graph: The graph to operate on. - begin: The start of the DFS. - end: The node that should be located. - """ - - def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: - return (edge.dst for edge in graph.out_edges(node)) - - to_visit: List[nodes.Node] = [begin] - seen: Set[nodes.Node] = set() - - while len(to_visit) > 0: - node: nodes.Node = to_visit.pop() - if node == end: - return True - elif node not in seen: - to_visit.extend(next_nodes(node)) - seen.add(node) - - # We never found `end` - return False - - def get_access_set( - self, - scope_node: Union[nodes.MapEntry, nodes.MapExit], - state: SDFGState, - ) -> Set[nodes.AccessNode]: - """Computes the access set of a "scope node". - - If `scope_node` is a `MapEntry` it will operate on the set of incoming edges - and if it is an `MapExit` on the set of outgoing edges. The function will - then determine all access nodes that have a connection through these edges - to the scope nodes (edges that does not lead to access nodes are ignored). - The function returns a set that contains all access nodes that were found. - It is important that this set will also contain views. - - Args: - scope_node: The scope node that should be evaluated. - state: The state in which we operate. - """ - if isinstance(scope_node, nodes.MapEntry): - get_edges = lambda node: state.in_edges(node) - other_node = lambda e: e.src - else: - get_edges = lambda node: state.out_edges(node) - other_node = lambda e: e.dst - access_set: Set[nodes.AccessNode] = { - node - for node in map(other_node, get_edges(scope_node)) if isinstance(node, nodes.AccessNode) - } - - return access_set - - def find_subsets( - self, - node: nodes.AccessNode, - scope_node: Union[nodes.MapExit, nodes.MapEntry], - state: SDFGState, - sdfg: SDFG, - repl_dict: Optional[Dict[str, str]], - ) -> List[subsets.Subset]: - """Finds all subsets that access `node` within `scope_node`. - - The function will not start a search for all consumer/producers. - Instead it will locate the edges which is immediately inside the - map scope. - - Args: - node: The access node that should be examined. - scope_node: We are only interested in data that flows through this node. - state: The state in which we operate. - sdfg: The SDFG object. - """ - - # Is the node used for reading or for writing. - # This influences how we have to proceed. - if isinstance(scope_node, nodes.MapEntry): - outer_edges_to_inspect = [e for e in state.in_edges(scope_node) if e.src == node] - get_subset = lambda e: e.data.src_subset - get_inner_edges = lambda e: state.out_edges_by_connector(scope_node, "OUT_" + e.dst_conn[3:]) - else: - outer_edges_to_inspect = [e for e in state.out_edges(scope_node) if e.dst == node] - get_subset = lambda e: e.data.dst_subset - get_inner_edges = lambda e: state.in_edges_by_connector(scope_node, "IN_" + e.src_conn[4:]) - - found_subsets: List[subsets.Subset] = [] - for edge in outer_edges_to_inspect: - found_subsets.extend(get_subset(e) for e in get_inner_edges(edge)) - assert len(found_subsets) > 0, "Could not find any subsets." - assert not any(subset is None for subset in found_subsets) - - found_subsets = copy.deepcopy(found_subsets) - if repl_dict: - for subset in found_subsets: - # Replace happens in place - symbolic.safe_replace(repl_dict, subset.replace) - - return found_subsets - - def is_view( - self, - node: nodes.AccessNode, - sdfg: SDFG, - ) -> bool: - """Tests if `node` points to a view or not.""" - node_desc: data.Data = node.desc(sdfg) - return isinstance(node_desc, data.View) - - def track_view( - self, - view: nodes.AccessNode, - state: SDFGState, - sdfg: SDFG, - ) -> nodes.AccessNode: - """Find the original data of a View. - - Given the View `view`, the function will trace the view back to the original - access node. For convenience, if `view` is not a `View` the argument will be - returned. - - Args: - view: The view that should be traced. - state: The state in which we operate. - sdfg: The SDFG on which we operate. - """ - - # Test if it is a view at all, if not return the passed node as source. - if not self.is_view(view, sdfg): - return view - - # First determine if the view is used for reading or writing. - curr_edge = dace.sdfg.utils.get_view_edge(state, view) - if curr_edge is None: - raise RuntimeError(f"Failed to determine the direction of the view '{view}'.") - if curr_edge.dst_conn == "views": - # The view is used for reading. - next_node = lambda curr_edge: curr_edge.src - elif curr_edge.src_conn == "views": - # The view is used for writing. - next_node = lambda curr_edge: curr_edge.dst - else: - raise RuntimeError(f"Failed to determine the direction of the view '{view}' | {curr_edge}.") - - # Now trace the view back. - org_view = view - view = next_node(curr_edge) - while self.is_view(view, sdfg): - curr_edge = dace.sdfg.utils.get_view_edge(state, view) - if curr_edge is None: - raise RuntimeError(f"View tracing of '{org_view}' failed at note '{view}'.") - view = next_node(curr_edge) - return view diff --git a/dace/transformation/dataflow/map_fusion_parallel.py b/dace/transformation/dataflow/map_fusion_parallel.py deleted file mode 100644 index 41e8e3bd3d..0000000000 --- a/dace/transformation/dataflow/map_fusion_parallel.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -"""Implements the parallel map fusing transformation.""" - -from typing import Any, Optional, Set, Union - -import dace -from dace import properties, transformation -from dace.sdfg import SDFG, SDFGState, graph, nodes - -from dace.transformation.dataflow import map_fusion_helper as mfh - - -@properties.make_properties -class MapFusionParallel(mfh.MapFusionHelper): - """The `MapFusionParallel` transformation allows to merge two parallel maps. - - While the `MapFusionSerial` transformation fuses maps that are sequentially - connected through an intermediate node, this transformation is able to fuse any - two maps that are not sequential and in the same scope. - - Args: - only_if_common_ancestor: Only perform fusion if both Maps share at least one - node as direct ancestor. This will increase the locality of the merge. - only_inner_maps: Only match Maps that are internal, i.e. inside another Map. - only_toplevel_maps: Only consider Maps that are at the top. - - Note: - This transformation only matches the entry nodes of the Map, but will also - modify the exit nodes of the Maps. - """ - - map_entry_1 = transformation.transformation.PatternNode(nodes.MapEntry) - map_entry_2 = transformation.transformation.PatternNode(nodes.MapEntry) - - only_if_common_ancestor = properties.Property( - dtype=bool, - default=False, - allow_none=False, - desc="Only perform fusing if the Maps share a node as parent.", - ) - - def __init__( - self, - only_if_common_ancestor: Optional[bool] = None, - **kwargs: Any, - ) -> None: - if only_if_common_ancestor is not None: - self.only_if_common_ancestor = only_if_common_ancestor - super().__init__(**kwargs) - - @classmethod - def expressions(cls) -> Any: - # This just matches _any_ two Maps inside a state. - state = graph.OrderedMultiDiConnectorGraph() - state.add_nodes_from([cls.map_entry_1, cls.map_entry_2]) - return [state] - - def can_be_applied( - self, - graph: Union[SDFGState, SDFG], - expr_index: int, - sdfg: dace.SDFG, - permissive: bool = False, - ) -> bool: - """Checks if the fusion can be done. - - The function checks the general fusing conditions and if the maps are parallel. - """ - map_entry_1: nodes.MapEntry = self.map_entry_1 - map_entry_2: nodes.MapEntry = self.map_entry_2 - - # Check the structural properties of the maps, this will also ensure that - # the two maps are in the same scope and the parameters can be renamed - if not self.can_be_fused( - map_entry_1=map_entry_1, - map_entry_2=map_entry_2, - graph=graph, - sdfg=sdfg, - permissive=permissive, - ): - return False - - # Since the match expression matches any twp Maps, we have to ensure that - # the maps are parallel. The `can_be_fused()` function already verified - # if they are in the same scope. - if not self.is_parallel(graph=graph, node1=map_entry_1, node2=map_entry_2): - return False - - # Test if they have they share a node as direct ancestor. - if self.only_if_common_ancestor: - # This assumes that there is only one access node per data container in the state. - ancestors_1: Set[nodes.Node] = {e1.src for e1 in graph.in_edges(map_entry_1)} - if not any(e2.src in ancestors_1 for e2 in graph.in_edges(map_entry_2)): - return False - - return True - - def is_parallel( - self, - graph: SDFGState, - node1: nodes.Node, - node2: nodes.Node, - ) -> bool: - """Tests if `node1` and `node2` are parallel. - - The nodes are parallel if `node2` can not be reached from `node1` and vice versa. - - Args: - graph: The graph to traverse. - node1: The first node to check. - node2: The second node to check. - """ - - # In order to be parallel they must be in the same scope. - scope = graph.scope_dict() - if scope[node1] != scope[node2]: - return False - - # The `all_nodes_between()` function traverse the graph and returns `None` if - # `end` was not found. We have to call it twice, because we do not know - # which node is upstream if they are not parallel. - if self.is_node_reachable_from(graph=graph, begin=node1, end=node2): - return False - elif self.is_node_reachable_from(graph=graph, begin=node2, end=node1): - return False - return True - - def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: - """Performs the Map fusing. - - Essentially, the function relocate all edges from the scope nodes (`MapEntry` - and `MapExit`) of the second map to the scope nodes of the first map. - """ - - map_entry_1: nodes.MapEntry = self.map_entry_1 - map_exit_1: nodes.MapExit = graph.exit_node(map_entry_1) - map_entry_2: nodes.MapEntry = self.map_entry_2 - map_exit_2: nodes.MapExit = graph.exit_node(map_entry_2) - - # Before we do anything we perform the renaming. - self.rename_map_parameters( - first_map=map_entry_1.map, - second_map=map_entry_2.map, - second_map_entry=map_entry_2, - state=graph, - ) - - for to_node, from_node in zip((map_entry_1, map_exit_1), (map_entry_2, map_exit_2)): - self.relocate_nodes( - from_node=from_node, - to_node=to_node, - state=graph, - sdfg=sdfg, - ) - # The relocate function does not remove the node, so we must do it. - graph.remove_node(from_node) diff --git a/dace/transformation/dataflow/map_fusion_serial.py b/dace/transformation/dataflow/map_fusion_serial.py deleted file mode 100644 index f284b4c7c8..0000000000 --- a/dace/transformation/dataflow/map_fusion_serial.py +++ /dev/null @@ -1,966 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -"""Implements the serial map fusing transformation.""" - -import copy -from typing import Any, Dict, List, Optional, Set, Tuple, Union - -import dace -from dace import data, dtypes, properties, subsets, symbolic, transformation -from dace.sdfg import SDFG, SDFGState, graph, nodes - -from dace.transformation.dataflow import map_fusion_helper as mfh - - -@properties.make_properties -class MapFusionSerial(mfh.MapFusionHelper): - """Fuse two serial maps together. - - The transformation combines two maps into one that are connected through some - access nodes. Conceptually this transformation removes the exit of the first - or upper map and the entry of the lower or second map and then rewrites the - connections appropriately. Depending on the situation the transformation will - either fully remove or make the intermediate a new output of the second map. - - By default, the transformation does not use the strict data flow mode, see - `MapFusionHelper` for more, however, it might be useful in come cases to enable - it. - - Args: - only_inner_maps: Only match Maps that are internal, i.e. inside another Map. - only_toplevel_maps: Only consider Maps that are at the top. - strict_dataflow: If `True`, the transformation ensures a more - stricter version of the data flow. - - Notes: - - This transformation modifies more nodes than it matches. - - After the transformation has been applied simplify should be run to remove - some dead data flow, that was introduced to ensure validity. - - A `MapFusionSerial` obejct can be initialized and be reused. However, - after new access nodes are added to any state, it is no longer valid - to use the object. - - Todo: - - Consider the case that only shared nodes are created (thus no inspection of - the graph is needed) and make all shared. Then use the dead dataflow - elimination transformation to get rid of the ones we no longer need. - - Increase the applicability. - """ - - map_exit_1 = transformation.transformation.PatternNode(nodes.MapExit) - intermediate_access_node = transformation.transformation.PatternNode(nodes.AccessNode) - map_entry_2 = transformation.transformation.PatternNode(nodes.MapEntry) - - def __init__( - self, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - - @classmethod - def expressions(cls) -> Any: - """Get the match expression. - - The transformation matches the exit node of the top Map that is connected to - an access node that again is connected to the entry node of the second Map. - An important note is, that the transformation operates not just on the - matched nodes, but more or less on anything that has an incoming connection - from the first Map or an outgoing connection to the second Map entry. - """ - return [dace.sdfg.utils.node_path_graph(cls.map_exit_1, cls.intermediate_access_node, cls.map_entry_2)] - - def can_be_applied( - self, - graph: Union[SDFGState, SDFG], - expr_index: int, - sdfg: dace.SDFG, - permissive: bool = False, - ) -> bool: - """Tests if the matched Maps can be merged. - - The two Maps are mergeable iff: - - Checks general requirements, see `MapFusionHelper.can_be_fused()`. - - Tests if the decomposition exists. - - Tests if there are read write dependencies. - """ - map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit_1) - map_exit_1: nodes.MapExit = self.map_exit_1 - map_entry_2: nodes.MapEntry = self.map_entry_2 - - # This essentially test the structural properties of the two Maps. - if not self.can_be_fused(map_entry_1=map_entry_1, map_entry_2=map_entry_2, graph=graph, sdfg=sdfg): - return False - - # Test for read-write conflicts - if self.has_read_write_dependency( - map_entry_1=map_entry_1, - map_entry_2=map_entry_2, - state=graph, - sdfg=sdfg, - ): - return False - - # Two maps can be serially fused if the node decomposition exists and - # at least one of the intermediate output sets is not empty. The state - # of the pure outputs is irrelevant for serial map fusion. - output_partition = self.partition_first_outputs( - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - ) - if output_partition is None: - return False - _, exclusive_outputs, shared_outputs = output_partition - if not (exclusive_outputs or shared_outputs): - return False - return True - - def has_read_write_dependency( - self, - map_entry_1: nodes.MapEntry, - map_entry_2: nodes.MapEntry, - state: SDFGState, - sdfg: SDFG, - ) -> bool: - """Test if there is a read write dependency between the two maps to be fused. - - The function checks two different things. - - The function will make sure that there is no read write dependency between - the input and output of the fused maps. For that it will inspect the - respective subsets. - - The second part partially checks the intermediate nodes, it mostly ensures - that there are not views and that they are not used as inputs or outputs - at the same time. However, the function will not check for read write - conflicts in this set, this is done in the partition function. - - Returns: - `True` if there is a conflict between the maps that can not be handled. - If there is no conflict or if the conflict can be handled `False` - is returned. - - Args: - map_entry_1: The entry node of the first map. - map_entry_2: The entry node of the second map. - state: The state on which we operate. - sdfg: The SDFG on which we operate. - """ - map_exit_1: nodes.MapExit = state.exit_node(map_entry_1) - map_exit_2: nodes.MapExit = state.exit_node(map_entry_2) - - # Get the read and write sets of the different maps, note that Views - # are not resolved yet. - access_sets: List[Dict[str, nodes.AccessNode]] = [] - for scope_node in [map_entry_1, map_exit_1, map_entry_2, map_exit_2]: - access_set: Set[nodes.AccessNode] = self.get_access_set(scope_node, state) - access_sets.append({node.data: node for node in access_set}) - # If two different access nodes of the same scoping node refers to the - # same data, then we consider this as a dependency we can not handle. - # It is only a problem for the intermediate nodes and might be possible - # to handle, but doing so is hard, so we just forbid it. - if len(access_set) != len(access_sets[-1]): - return True - read_map_1, write_map_1, read_map_2, write_map_2 = access_sets - - # It might be possible that there are views, so we have to resolve them. - # We also already get the name of the data container. - # Note that `len(real_read_map_1) <= len(read_map_1)` holds because of Views. - resolved_sets: List[Set[str]] = [] - for unresolved_set in [read_map_1, write_map_1, read_map_2, write_map_2]: - resolved_sets.append({ - self.track_view(node, state, sdfg).data if self.is_view(node, sdfg) else node.data - for node in unresolved_set.values() - }) - # If the resolved and unresolved names do not have the same length. - # Then different views point to the same location, which we forbid - if len(unresolved_set) != len(resolved_sets[-1]): - return False - real_read_map_1, real_write_map_1, real_read_map_2, real_write_map_2 = resolved_sets - - # We do not allow that the first and second map each write to the same data. - if not real_write_map_1.isdisjoint(real_write_map_2): - return True - - # If there is no overlap in what is (totally) read and written, there will be no conflict. - # This must come before the check of disjoint write. - if (real_read_map_1 | real_read_map_2).isdisjoint(real_write_map_1 | real_write_map_2): - return False - - # These are the names (unresolved) and the access nodes of the data that is used - # to transmit information between the maps. The partition function ensures that - # these nodes are directly connected to the two maps. - exchange_names: Set[str] = set(write_map_1.keys()).intersection(read_map_2.keys()) - exchange_nodes: Set[nodes.AccessNode] = set(write_map_1.values()).intersection(read_map_2.values()) - - # If the number are different then a data is accessed through multiple nodes. - if len(exchange_names) != len(exchange_nodes): - return True - assert all(exchange_node.data in exchange_names for exchange_node in exchange_nodes) - - # For simplicity we assume that the nodes used for exchange are not views. - if any(self.is_view(exchange_node, sdfg) for exchange_node in exchange_nodes): - return True - - # This is the names of the node that are used as input of the first map and - # as output of the second map. We have to ensure that there is no data - # dependency between these nodes. - fused_inout_data_names: Set[str] = set(read_map_1.keys()).intersection(write_map_2.keys()) - - # If a data container is used as input and output then it can not be a view (simplicity) - if any(self.is_view(read_map_1[name], sdfg) for name in fused_inout_data_names): - return True - - # A data container can be used as input and output. But we do not allow that - # it is also used as intermediate or exchange data. This is an important check. - if not fused_inout_data_names.isdisjoint(exchange_names): - return True - - # Get the replacement dict for changing the map variables from the subsets of - # the second map. - repl_dict = self.find_parameter_remapping(map_entry_1.map, map_exit_2.map) - - # Now we inspect if there is a read write dependency, between data that is - # used as input and output of the fused map. There is no problem is they - # are pointwise, i.e. in each iteration the same locations are accessed. - # Essentially they all boil down to `a += 1`. - for inout_data_name in fused_inout_data_names: - all_subsets: List[subsets.Subset] = [] - # The subsets that define reading are given by the first map's entry node - all_subsets.extend( - self.find_subsets( - node=read_map_1[inout_data_name], - scope_node=map_entry_1, - state=state, - sdfg=sdfg, - repl_dict=None, - )) - # While the subsets defining writing are given by the second map's exit - # node, there we also have to apply renaming. - all_subsets.extend( - self.find_subsets( - node=write_map_2[inout_data_name], - scope_node=map_exit_2, - state=state, - sdfg=sdfg, - repl_dict=repl_dict, - )) - # Now we can test if these subsets are point wise - if not self.test_if_subsets_are_point_wise(all_subsets): - return True - - # No read write dependency was found. - return False - - def test_if_subsets_are_point_wise(self, subsets_to_check: List[subsets.Subset]) -> bool: - """Point wise means that they are all the same. - - If a series of subsets are point wise it means that all Memlets, access - the same data. This is an important property because the whole map fusion - is build upon this. - If the subsets originates from different maps, then they must have been - renamed. - - Args: - subsets_to_check: The list of subsets that should be checked. - """ - assert len(subsets_to_check) > 1 - - # We will check everything against the master subset. - master_subset = subsets_to_check[0] - for ssidx in range(1, len(subsets_to_check)): - subset = subsets_to_check[ssidx] - if isinstance(subset, subsets.Indices): - subset = subsets.Range.from_indices(subset) - # Do we also need the reverse? See below why. - if any(r != (0, 0, 1) for r in subset.offset_new(master_subset, negative=True)): - return False - else: - # The original code used `Range.offset` here, but that one had trouble - # for `r1 = 'j, 0:10'` and `r2 = 'j, 0`. The solution would be to test - # symmetrically, i.e. `r1 - r2` and `r2 - r1`. However, if we would - # have `r2_1 = 'j, 0:10'` it consider it as failing, which is not - # what we want. Thus we will use symmetric cover. - if not master_subset.covers(subset): - return False - if not subset.covers(master_subset): - return False - - # All subsets are equal to the master subset, thus they are equal to each other. - # This means that the data accesses, described by this transformation is - # point wise - return True - - def compute_offset_subset( - self, - original_subset: subsets.Range, - intermediate_desc: data.Data, - map_params: List[str], - producer_offset: Union[subsets.Range, None], - ) -> subsets.Range: - """Computes the memlet to correct read and writes of the intermediate. - - This is the value that must be substracted from the memlets to adjust, i.e - (`memlet_to_adjust(correction, negative=True)`). If `producer_offset` is - `None` then the function computes the correction that should be applied to - the producer memlets, i.e. the memlets of the tree converging at - `intermediate_node`. If `producer_offset` is given, it should be the output - of the previous call to this function, with `producer_offset=None`. In this - case the function computes the correction for the consumer side, i.e. the - memlet tree that originates at `intermediate_desc`. - - Args: - original_subset: The original subset that was used to write into the - intermediate, must be renamed to the final map parameter. - intermediate_desc: The original intermediate data descriptor. - map_params: The parameter of the final map. - producer_offset: The correction that was applied to the producer side. - """ - assert not isinstance(intermediate_desc, data.View) - final_offset: subsets.Range = None - if isinstance(intermediate_desc, data.Scalar): - # If the intermediate was a scalar, then it will remain a scalar. - # Thus there is no correction that we must apply. - return subsets.Range.from_string("0") - - elif isinstance(intermediate_desc, data.Array): - basic_offsets = original_subset.min_element() - offset_list = [] - for d in range(original_subset.dims()): - d_range = subsets.Range([original_subset[d]]) - if d_range.free_symbols.intersection(map_params): - offset_list.append(d_range[0]) - else: - offset_list.append((basic_offsets[d], basic_offsets[d], 1)) - final_offset = subsets.Range(offset_list) - - else: - raise TypeError(f"Does not know how to compute the subset offset for '{type(intermediate_desc).__name__}'.") - - if producer_offset is not None: - # Here we are correcting some parts that over approximate (which partially - # does under approximate) might screw up. Consider two maps, the first - # map only writes the subset `[:, 2:6]`, thus the new intermediate will - # have shape `(1, 4)`. Now also imagine that the second map only reads - # the elements `[:, 3]`. From this we see that we can only correct the - # consumer side if we also take the producer side into consideration! - # See also the `transformations/mapfusion_test.py::test_offset_correction_*` - # tests for more. - final_offset.offset( - final_offset.offset_new( - producer_offset, - negative=True, - ), - negative=True, - ) - return final_offset - - def partition_first_outputs( - self, - state: SDFGState, - sdfg: SDFG, - map_exit_1: nodes.MapExit, - map_entry_2: nodes.MapEntry, - ) -> Union[ - Tuple[ - Set[graph.MultiConnectorEdge[dace.Memlet]], - Set[graph.MultiConnectorEdge[dace.Memlet]], - Set[graph.MultiConnectorEdge[dace.Memlet]], - ], - None, - ]: - """Partition the output edges of `map_exit_1` for serial map fusion. - - The output edges of the first map are partitioned into three distinct sets, - defined as follows: - - Pure Output Set `\mathbb{P}`: - These edges exits the first map and does not enter the second map. These - outputs will be simply be moved to the output of the second map. - - Exclusive Intermediate Set `\mathbb{E}`: - Edges in this set leaves the first map exit, enters an access node, from - where a Memlet then leads immediately to the second map. The memory - referenced by this access node is not used anywhere else, thus it can - be removed. - - Shared Intermediate Set `\mathbb{S}`: - These edges are very similar to the one in `\mathbb{E}` except that they - are used somewhere else, thus they can not be removed and have to be - recreated as output of the second map. - - If strict data flow mode is enabled the function is rather strict if an - output can be added to either intermediate set and might fail to compute - the partition, even if it would exist. - - Returns: - If such a decomposition exists the function will return the three sets - mentioned above in the same order. - In case the decomposition does not exist, i.e. the maps can not be fused - the function returns `None`. - - Args: - state: The in which the two maps are located. - sdfg: The full SDFG in whcih we operate. - map_exit_1: The exit node of the first map. - map_entry_2: The entry node of the second map. - """ - # The three outputs set. - pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() - exclusive_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() - shared_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() - - # Compute the renaming that for translating the parameter of the _second_ - # map to the ones used by the first map. - repl_dict: Dict[str, str] = self.find_parameter_remapping( - first_map=map_exit_1.map, - second_map=map_entry_2.map, - ) - assert repl_dict is not None - - # Set of intermediate nodes that we have already processed. - processed_inter_nodes: Set[nodes.Node] = set() - - # These are the data that is written to multiple times in _this_ state. - # If a data is written to multiple time in a state, it could be - # classified as shared. However, it might happen that the node has zero - # degree. This is not a problem as the maps also induced a before-after - # relationship. But some DaCe transformations do not catch this. - # Thus we will never modify such intermediate nodes and fail instead. - if self.strict_dataflow: - multi_write_data: Set[str] = self._compute_multi_write_data(state, sdfg) - else: - multi_write_data = set() - - # Now scan all output edges of the first exit and classify them - for out_edge in state.out_edges(map_exit_1): - intermediate_node: nodes.Node = out_edge.dst - - # We already processed the node, this should indicate that we should - # run simplify again, or we should start implementing this case. - # TODO(phimuell): Handle this case, already partially handled here. - if intermediate_node in processed_inter_nodes: - return None - processed_inter_nodes.add(intermediate_node) - - # The intermediate can only have one incoming degree. It might be possible - # to handle multiple incoming edges, if they all come from the top map. - # However, the resulting SDFG might be invalid. - # NOTE: Allow this to happen (under certain cases) if the only producer - # is the top map. - if state.in_degree(intermediate_node) != 1: - return None - - # If the second map is not reachable from the intermediate node, then - # the output is pure and we can end here. - if not self.is_node_reachable_from( - graph=state, - begin=intermediate_node, - end=map_entry_2, - ): - pure_outputs.add(out_edge) - continue - - # The following tests are _after_ we have determined if we have a pure - # output node, because this allows us to handle more exotic pure node - # cases, as handling them is essentially rerouting an edge, whereas - # handling intermediate nodes is much more complicated. - - # For us an intermediate node must always be an access node, because - # everything else we do not know how to handle. It is important that - # we do not test for non transient data here, because they can be - # handled has shared intermediates. - if not isinstance(intermediate_node, nodes.AccessNode): - return None - if self.is_view(intermediate_node, sdfg): - return None - - # Checks if the intermediate node refers to data that is accessed by - # _other_ access nodes in _this_ state. If this is the case then never - # touch this intermediate node. - # TODO(phimuell): Technically it would be enough to turn the node into - # a shared output node, because this will still fulfil the dependencies. - # However, some DaCe transformation can not handle this properly, so we - # are _forced_ to reject this node. - if intermediate_node.data in multi_write_data: - return None - - # Empty Memlets are only allowed if they are in `\mathbb{P}`, which - # is also the only place they really make sense (for a map exit). - # Thus if we now found an empty Memlet we reject it. - if out_edge.data.is_empty(): - return None - - # It can happen that multiple edges converges at the `IN_` connector - # of the first map exit, but there is only one edge leaving the exit. - # It is complicate to handle this, so for now we ignore it. - # TODO(phimuell): Handle this case properly. - # To handle this we need to associate a consumer edge (the outgoing edges - # of the second map) with exactly one producer. - producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list( - state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:])) - if len(producer_edges) > 1: - return None - - # Now check the constraints we have on the producers. - # - The source of the producer can not be a view (we do not handle this) - # - The edge shall also not be a reduction edge. - # - Defined location to where they write. - # - No dynamic Melets. - # Furthermore, we will also extract the subsets, i.e. the location they - # modify inside the intermediate array. - # Since we do not allow for WCR, we do not check if the producer subsets intersects. - producer_subsets: List[subsets.Subset] = [] - for producer_edge in producer_edges: - if isinstance(producer_edge.src, nodes.AccessNode) and self.is_view(producer_edge.src, sdfg): - return None - if producer_edge.data.dynamic: - return None - if producer_edge.data.wcr is not None: - return None - if producer_edge.data.dst_subset is None: - return None - producer_subsets.append(producer_edge.data.dst_subset) - - # Check if the producer do not intersect - if len(producer_subsets) == 1: - pass - elif len(producer_subsets) == 2: - if producer_subsets[0].intersects(producer_subsets[1]): - return None - else: - for i, psbs1 in enumerate(producer_subsets): - for j, psbs2 in enumerate(producer_subsets): - if i == j: - continue - if psbs1.intersects(psbs2): - return None - - # Now we determine the consumer of nodes. For this we are using the edges - # leaves the second map entry. It is not necessary to find the actual - # consumer nodes, as they might depend on symbols of nested Maps. - # For the covering test we only need their subsets, but we will perform - # some scan and filtering on them. - found_second_map = False - consumer_subsets: List[subsets.Subset] = [] - for intermediate_node_out_edge in state.out_edges(intermediate_node): - - # If the second map entry is not immediately reachable from the intermediate - # node, then ensure that there is not path that goes to it. - if intermediate_node_out_edge.dst is not map_entry_2: - if self.is_node_reachable_from(graph=state, begin=intermediate_node_out_edge.dst, end=map_entry_2): - return None - continue - - # Ensure that the second map is found exactly once. - # TODO(phimuell): Lift this restriction. - if found_second_map: - return None - found_second_map = True - - # The output of the top map can not define a dynamic map range in the - # second map. - if not intermediate_node_out_edge.dst_conn.startswith("IN_"): - return None - - # Now we look at all edges that leave the second map entry, i.e. the - # edges that feeds the consumer and define what is read inside the map. - # We do not check them, but collect them and inspect them. - # NOTE: The subset still uses the old iteration variables. - for inner_consumer_edge in state.out_edges_by_connector( - map_entry_2, "OUT_" + intermediate_node_out_edge.dst_conn[3:]): - if inner_consumer_edge.data.src_subset is None: - return None - if inner_consumer_edge.data.dynamic: - # TODO(phimuell): Is this restriction necessary, I am not sure. - return None - consumer_subsets.append(inner_consumer_edge.data.src_subset) - assert found_second_map, f"Found '{intermediate_node}' which looked like a pure node, but is not one." - assert len(consumer_subsets) != 0 - - # The consumer still uses the original symbols of the second map, so we must rename them. - if repl_dict: - consumer_subsets = copy.deepcopy(consumer_subsets) - for consumer_subset in consumer_subsets: - symbolic.safe_replace(mapping=repl_dict, replace_callback=consumer_subset.replace) - - # Now we are checking if a single iteration of the first (top) map - # can satisfy all data requirements of the second (bottom) map. - # For this we look if the producer covers the consumer. A consumer must - # be covered by exactly one producer. - for consumer_subset in consumer_subsets: - nb_coverings = sum(producer_subset.covers(consumer_subset) for producer_subset in producer_subsets) - if nb_coverings != 1: - return None - - # After we have ensured coverage, we have to decide if the intermediate - # node can be removed (`\mathbb{E}`) or has to be restored (`\mathbb{S}`). - # Note that "removed" here means that it is reconstructed by a new - # output of the second map. - if self.is_shared_data(intermediate_node, sdfg): - # The intermediate data is used somewhere else, either in this or another state. - shared_outputs.add(out_edge) - else: - # The intermediate can be removed, as it is not used anywhere else. - exclusive_outputs.add(out_edge) - - assert len(processed_inter_nodes) == sum(len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs]) - return (pure_outputs, exclusive_outputs, shared_outputs) - - def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: - """Performs the serial Map fusing. - - The function first computes the map decomposition and then handles the - three sets. The pure outputs are handled by `relocate_nodes()` while - the two intermediate sets are handled by `handle_intermediate_set()`. - - By assumption we do not have to rename anything. - - Args: - graph: The SDFG state we are operating on. - sdfg: The SDFG we are operating on. - """ - # NOTE: `self.map_*` actually stores the ID of the node. - # once we start adding and removing nodes it seems that their ID changes. - # Thus we have to save them here, this is a known behaviour in DaCe. - assert isinstance(graph, dace.SDFGState) - assert isinstance(self.map_exit_1, nodes.MapExit) - assert isinstance(self.map_entry_2, nodes.MapEntry) - - map_exit_1: nodes.MapExit = self.map_exit_1 - map_entry_2: nodes.MapEntry = self.map_entry_2 - map_exit_2: nodes.MapExit = graph.exit_node(self.map_entry_2) - map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit_1) - - # Before we do anything we perform the renaming. - self.rename_map_parameters( - first_map=map_exit_1.map, - second_map=map_entry_2.map, - second_map_entry=map_entry_2, - state=graph, - ) - - output_partition = self.partition_first_outputs( - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - ) - assert output_partition is not None # Make MyPy happy. - pure_outputs, exclusive_outputs, shared_outputs = output_partition - - if len(exclusive_outputs) != 0: - self.handle_intermediate_set( - intermediate_outputs=exclusive_outputs, - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - map_exit_2=map_exit_2, - is_exclusive_set=True, - ) - if len(shared_outputs) != 0: - self.handle_intermediate_set( - intermediate_outputs=shared_outputs, - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - map_exit_2=map_exit_2, - is_exclusive_set=False, - ) - assert pure_outputs == set(graph.out_edges(map_exit_1)) - if len(pure_outputs) != 0: - self.relocate_nodes( - from_node=map_exit_1, - to_node=map_exit_2, - state=graph, - sdfg=sdfg, - ) - - # Above we have handled the input of the second map and moved them - # to the first map, now we must move the output of the first map - # to the second one, as this one is used. - self.relocate_nodes( - from_node=map_entry_2, - to_node=map_entry_1, - state=graph, - sdfg=sdfg, - ) - - for node_to_remove in [map_exit_1, map_entry_2]: - assert graph.degree(node_to_remove) == 0 - graph.remove_node(node_to_remove) - - # Now turn the second output node into the output node of the first Map. - map_exit_2.map = map_entry_1.map - - def handle_intermediate_set( - self, - intermediate_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]], - state: SDFGState, - sdfg: SDFG, - map_exit_1: nodes.MapExit, - map_entry_2: nodes.MapEntry, - map_exit_2: nodes.MapExit, - is_exclusive_set: bool, - ) -> None: - """This function handles the intermediate sets. - - The function is able to handle both the shared and exclusive intermediate - output set, see `partition_first_outputs()`. The main difference is that - in exclusive mode the intermediate nodes will be fully removed from - the SDFG. While in shared mode the intermediate node will be preserved. - The function assumes that the parameter renaming was already done. - - Args: - intermediate_outputs: The set of outputs, that should be processed. - state: The state in which the map is processed. - sdfg: The SDFG that should be optimized. - map_exit_1: The exit of the first/top map. - map_entry_2: The entry of the second map. - map_exit_2: The exit of the second map. - is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. - - Notes: - Before the transformation the `state` does not have to be valid and - after this function has run the state is (most likely) invalid. - """ - - map_params = map_exit_1.map.params.copy() - - # Now we will iterate over all intermediate edges and process them. - # If not stated otherwise the comments assume that we run in exclusive mode. - for out_edge in intermediate_outputs: - # This is the intermediate node that, that we want to get rid of. - # In shared mode we want to recreate it after the second map. - inter_node: nodes.AccessNode = out_edge.dst - inter_name = inter_node.data - inter_desc = inter_node.desc(sdfg) - inter_shape = inter_desc.shape - - # Now we will determine the shape of the new intermediate. This size of - # this temporary is given by the Memlet that goes into the first map exit. - pre_exit_edges = list(state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:])) - if len(pre_exit_edges) != 1: - raise NotImplementedError() - pre_exit_edge = pre_exit_edges[0] - new_inter_shape_raw = symbolic.overapproximate(pre_exit_edge.data.subset.size()) - - # Over approximation will leave us with some unneeded size one dimensions. - # If they are removed some dace transformations (especially auto optimization) - # will have problems. - if not self.strict_dataflow: - squeezed_dims: List[int] = [] # These are the dimensions we removed. - new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. - for dim, (proposed_dim_size, full_dim_size) in enumerate(zip(new_inter_shape_raw, inter_shape)): - if full_dim_size == 1: # Must be kept! - new_inter_shape.append(proposed_dim_size) - elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. - squeezed_dims.append(dim) - else: - new_inter_shape.append(proposed_dim_size) - else: - squeezed_dims = [] - new_inter_shape = list(new_inter_shape_raw) - - # This is the name of the new "intermediate" node that we will create. - # It will only have the shape `new_inter_shape` which is basically its - # output within one Map iteration. - # NOTE: The insertion process might generate a new name. - new_inter_name: str = f"__s{sdfg.node_id(state)}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" - - # Now generate the intermediate data container. - if len(new_inter_shape) == 0: - assert pre_exit_edge.data.subset.num_elements() == 1 - is_scalar = True - new_inter_name, new_inter_desc = sdfg.add_scalar( - new_inter_name, - dtype=inter_desc.dtype, - transient=True, - storage=dtypes.StorageType.Register, - find_new_name=True, - ) - - else: - assert (pre_exit_edge.data.subset.num_elements() > 1) or all(x == 1 for x in new_inter_shape) - is_scalar = False - new_inter_name, new_inter_desc = sdfg.add_transient( - new_inter_name, - shape=new_inter_shape, - dtype=inter_desc.dtype, - find_new_name=True, - ) - new_inter_node: nodes.AccessNode = state.add_access(new_inter_name) - - # Get the subset that defined into which part of the old intermediate - # the old output edge wrote to. We need that to adjust the producer - # Memlets, since they now write into the new (smaller) intermediate. - assert pre_exit_edge.data.data == inter_name - assert pre_exit_edge.data.dst_subset is not None - producer_offset = self.compute_offset_subset( - original_subset=pre_exit_edge.data.dst_subset, - intermediate_desc=inter_desc, - map_params=map_params, - producer_offset=None, - ) - - # Memlets have a lot of additional informations, such as dynamic. - # To ensure that we get all of them, we will now copy them and modify - # the one that was originally there. We also hope that propagate will - # set the rest for us correctly. - new_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) - new_pre_exit_memlet.data = new_inter_name - new_pre_exit_memlet.dst_subset = subsets.Range.from_array(new_inter_desc) - - # New we will reroute the output Memlet, thus it will no longer pass - # through the Map exit but through the newly created intermediate. - # NOTE: We will delete the previous edge later. - new_pre_exit_edge = state.add_edge( - pre_exit_edge.src, - pre_exit_edge.src_conn, - new_inter_node, - None, - new_pre_exit_memlet, - ) - - # We now handle the MemletTree defined by this edge. - # The newly created edge, only handled the last collection step. - for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children(include_self=False): - producer_edge = producer_tree.edge - - # Associate the (already existing) Memlet with the new data. - # TODO(phimuell): Improve the code below to remove the check. - assert producer_edge.data.data == inter_name - producer_edge.data.data = new_inter_name - - if is_scalar: - producer_edge.data.dst_subset = "0" - elif producer_edge.data.dst_subset is not None: - # Since we now write into a smaller memory patch, we must - # compensate for that. We do this by substracting where the write - # originally had begun. - producer_edge.data.dst_subset.offset(producer_offset, negative=True) - producer_edge.data.dst_subset.pop(squeezed_dims) - - # Now after we have handled the input of the new intermediate node, - # we must handle its output. For this we have to "inject" the newly - # created intermediate into the second map. We do this by finding - # the input connectors on the map entry, such that we know where we - # have to reroute inside the Map. - # NOTE: Assumes that map (if connected is the direct neighbour). - conn_names: Set[str] = set() - for inter_node_out_edge in state.out_edges(inter_node): - if inter_node_out_edge.dst == map_entry_2: - assert inter_node_out_edge.dst_conn.startswith("IN_") - conn_names.add(inter_node_out_edge.dst_conn) - else: - # If we found another target than the second map entry from the - # intermediate node it means that the node _must_ survive, - # i.e. we are not in exclusive mode. - assert not is_exclusive_set - - # Now we will reroute the connections inside the second map, i.e. - # instead of consuming the old intermediate node, they will now - # consume the new intermediate node. - for in_conn_name in conn_names: - out_conn_name = "OUT_" + in_conn_name[3:] - - for inner_edge in state.out_edges_by_connector(map_entry_2, out_conn_name): - assert inner_edge.data.data == inter_name # DIRECTION!! - - # As for the producer side, we now read from a smaller array, - # So we must offset them, we use the original edge for this. - assert inner_edge.data.src_subset is not None - consumer_offset = self.compute_offset_subset( - original_subset=inner_edge.data.src_subset, - intermediate_desc=inter_desc, - map_params=map_params, - producer_offset=producer_offset, - ) - - # Now we create a new connection that instead reads from the new - # intermediate, instead of the old one. For this we use the - # old Memlet as template. However it is not fully initialized. - new_inner_memlet = copy.deepcopy(inner_edge.data) - new_inner_memlet.data = new_inter_name - - # Now we replace the edge from the SDFG. - state.remove_edge(inner_edge) - new_inner_edge = state.add_edge( - new_inter_node, - None, - inner_edge.dst, - inner_edge.dst_conn, - new_inner_memlet, - ) - - # Now modifying the Memlet, we do it after the insertion to make - # sure that the Memlet was properly initialized. - if is_scalar: - new_inner_memlet.subset = "0" - elif new_inner_memlet.src_subset is not None: - new_inner_memlet.src_subset.offset(consumer_offset, negative=True) - new_inner_memlet.src_subset.pop(squeezed_dims) - - # Now we have to make sure that all consumers are properly updated. - for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(include_self=False): - assert consumer_tree.edge.data.data == inter_name - - consumer_edge = consumer_tree.edge - consumer_edge.data.data = new_inter_name - if is_scalar: - consumer_edge.data.src_subset = "0" - elif consumer_edge.data.src_subset is not None: - consumer_edge.data.src_subset.offset(consumer_offset, negative=True) - consumer_edge.data.src_subset.pop(squeezed_dims) - - # The edge that leaves the second map entry was already deleted. We now delete - # the edges that connected the intermediate node with the second map entry. - for edge in list(state.in_edges_by_connector(map_entry_2, in_conn_name)): - assert edge.src == inter_node - state.remove_edge(edge) - map_entry_2.remove_in_connector(in_conn_name) - map_entry_2.remove_out_connector(out_conn_name) - - if is_exclusive_set: - # In exclusive mode the old intermediate node is no longer needed. - # This will also remove `out_edge` from the SDFG. - assert state.degree(inter_node) == 1 - state.remove_edge_and_connectors(out_edge) - state.remove_node(inter_node) - - state.remove_edge(pre_exit_edge) - map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) - map_exit_1.remove_out_connector(out_edge.src_conn) - del sdfg.arrays[inter_name] - - else: - # This is the shared mode, so we have to recreate the intermediate - # node, but this time it is at the exit of the second map. - state.remove_edge(pre_exit_edge) - map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) - - # This is the Memlet that goes from the map internal intermediate - # temporary node to the Map output. This will essentially restore - # or preserve the output for the intermediate node. It is important - # that we use the data that `preExitEdge` was used. - final_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) - assert pre_exit_edge.data.data == inter_name - final_pre_exit_memlet.other_subset = subsets.Range.from_array(new_inter_desc) - - new_pre_exit_conn = map_exit_2.next_connector() - state.add_edge( - new_inter_node, - None, - map_exit_2, - "IN_" + new_pre_exit_conn, - final_pre_exit_memlet, - ) - state.add_edge( - map_exit_2, - "OUT_" + new_pre_exit_conn, - inter_node, - out_edge.dst_conn, - copy.deepcopy(out_edge.data), - ) - map_exit_2.add_in_connector("IN_" + new_pre_exit_conn) - map_exit_2.add_out_connector("OUT_" + new_pre_exit_conn) - - map_exit_1.remove_out_connector(out_edge.src_conn) - state.remove_edge(out_edge) From 259d17c1b233ef25535c7f3f79a29d52ed48105f Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 1 Nov 2024 09:46:23 +0100 Subject: [PATCH 080/115] Added a new test. --- tests/transformations/mapfusion_test.py | 142 +++++++++++------------- 1 file changed, 66 insertions(+), 76 deletions(-) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 0702f2dfb7..42ad7c2716 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -8,7 +8,7 @@ from dace import SDFG, SDFGState from dace.sdfg import nodes -from dace.transformation.dataflow import MapFusionSerial, MapFusionParallel, MapExpansion +from dace.transformation.dataflow import MapFusion, MapExpansion def count_node(sdfg: SDFG, node_type): @@ -33,7 +33,7 @@ def apply_fusion( """ num_maps_before = count_node(sdfg, nodes.MapEntry) org_sdfg = copy.deepcopy(sdfg) - sdfg.apply_transformations_repeated(MapFusionSerial, validate=True, validate_all=True) + sdfg.apply_transformations_repeated(MapFusion, validate=True, validate_all=True) num_maps_after = count_node(sdfg, nodes.MapEntry) has_processed = False @@ -496,7 +496,7 @@ def test_interstate_fusion(): ref_C = A + 30 ref_D = A + 26 - assert sdfg.apply_transformations_repeated(MapFusionSerial, validate=True, validate_all=True) == 1 + assert sdfg.apply_transformations_repeated(MapFusion, validate=True, validate_all=True) == 1 assert sdfg.number_of_nodes() == 2 assert len([node for node in state1.data_nodes() if node.data == "B"]) == 1 @@ -506,79 +506,6 @@ def test_interstate_fusion(): assert np.allclose(D, ref_D) -def test_parallel_fusion_simple(): - N1, N2 = 10, 20 - - def _make_sdfg(): - sdfg = dace.SDFG("simple_parallel_map") - state = sdfg.add_state("state", is_start_block=True) - for name in ("A", "B", "out1", "out2"): - sdfg.add_array(name, shape=(N1, N2), transient=False, dtype=dace.float64) - sdfg.add_scalar("dmr", dtype=dace.float64, transient=False) - A, B, dmr, out1, out2 = (state.add_access(name) for name in ("A", "B", "dmr", "out1", "out2")) - - _, map1_entry, _ = state.add_mapped_tasklet( - "map_with_dynamic_range", - map_ranges={"__i0": f"0:{N1}", "__i1": f"0:{N2}"}, - inputs={"__in0": dace.Memlet("A[__i0, __i1]")}, - code="__out = __in0 + dynamic_range_value", - outputs={"__out": dace.Memlet("out1[__i0, __i1]")}, - input_nodes={"A": A}, - output_nodes={"out1": out1}, - external_edges=True, - ) - state.add_edge( - dmr, - None, - map1_entry, - "dynamic_range_value", - dace.Memlet("dmr[0]"), - ) - map1_entry.add_in_connector("dynamic_range_value") - - _, map2_entry, _ = state.add_mapped_tasklet( - "map_without_dynamic_range", - map_ranges={"__i2": f"0:{N1}", "__i3": f"0:{N2}"}, - inputs={ - "__in0": dace.Memlet("A[__i2, __i3]"), - "__in1": dace.Memlet("B[__i2, __i3]") - }, - code="__out = __in0 + __in1", - outputs={"__out": dace.Memlet("out2[__i2, __i3]")}, - input_nodes={"A": A, "B": B}, - output_nodes={"out2": out2}, - external_edges=True, - ) - sdfg.validate() - return sdfg, map1_entry, map2_entry - - for mode in range(2): - A = np.random.rand(N1, N2) - B = np.random.rand(N1, N2) - dmr = 3.1415 - out1 = np.zeros_like(A) - out2 = np.zeros_like(B) - res1 = A + dmr - res2 = A + B - - sdfg, map1_entry, map2_entry = _make_sdfg() - - if mode: - map1_entry, map2_entry = map2_entry, map1_entry - - MapFusionParallel.apply_to( - sdfg, - map_entry_1=map1_entry, - map_entry_2=map2_entry, - verify=True, - ) - assert count_node(sdfg, dace.sdfg.nodes.MapEntry) == 1 - - sdfg(A=A, B=B, dmr=dmr, out1=out1, out2=out2) - assert np.allclose(out1, res1) - assert np.allclose(out2, res2) - - def test_fuse_indirect_accesses(): @dace.program(auto_optimize=False) @@ -730,7 +657,70 @@ def test_offset_correction_empty(): apply_fusion(sdfg, removed_maps=0) +def test_different_access(): + + def exptected(A, B): + N, M = A.shape + return (A + 1) + B[1:(N+1), 2:(M+2)] + + def _make_sdfg(N: int, M: int) -> dace.SDFG: + sdfg = dace.SDFG("test_different_access") + names = ["A", "B", "__tmp", "__return"] + def_shape = (N, M) + sshape = {"B": (N+1, M+2), "__tmp": (N+1, M+1)} + for name in names: + sdfg.add_array( + name, + shape=sshape.get(name, def_shape), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["__tmp"].transient = True + + state = sdfg.add_state(is_start_block=True) + A, B, _tmp, _return = (state.add_access(name) for name in names) + + state.add_mapped_tasklet( + "comp1", + map_ranges={"__i0": f"0:{N}", "__i1": f"0:{M}"}, + inputs={"__in": dace.Memlet("A[__i0, __i1]")}, + code="__out = __in + 1.0", + outputs={"__out": dace.Memlet("__tmp[__i0 + 1, __i1 + 1]")}, + input_nodes={A}, + output_nodes={_tmp}, + external_edges=True, + ) + state.add_mapped_tasklet( + "comp2", + map_ranges={"__i0": f"0:{N}", "__i1": f"0:{M}"}, + inputs={ + "__in1": dace.Memlet("__tmp[__i0 + 1, __i1 + 1]"), + "__in2": dace.Memlet("B[__i0 + 1, __i1 + 2]"), + }, + code="__out = __in1 + __in2", + outputs={"__out": dace.Memlet("__return[__i0, __i1]")}, + input_nodes={_tmp, B}, + output_nodes={_return}, + external_edges=True, + ) + + sdfg.validate() + return sdfg + + N, M = 14, 17 + sdfg = _make_sdfg(N, M) + apply_fusion(sdfg, final_maps=1) + + A = np.array(np.random.rand(N, M), dtype=np.float64, copy=True) + B = np.array(np.random.rand(N + 1, M + 2), dtype=np.float64, copy=True) + + ref = exptected(A, B) + res = sdfg(A=A, B=B) + assert np.allclose(ref, res) + + if __name__ == '__main__': + test_different_access() test_indirect_accesses() test_fusion_shared() test_fusion_with_transient() From db263200fe7c8be0b8940e0ab7e9e7291098491b Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 1 Nov 2024 09:46:41 +0100 Subject: [PATCH 081/115] Fixed a missing include. --- dace/transformation/dataflow/map_fusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 032cf45634..eff04bd380 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -2,7 +2,7 @@ """Implements the serial map fusing transformation.""" import copy -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union, Iterable import dace from dace import data, dtypes, properties, subsets, symbolic, transformation From fa67492b400c9f473b3b5c46e9495787c1683f65 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 1 Nov 2024 09:49:13 +0100 Subject: [PATCH 082/115] Revert "Added a new test." This reverts commit 259d17c1b233ef25535c7f3f79a29d52ed48105f. --- tests/transformations/mapfusion_test.py | 142 +++++++++++++----------- 1 file changed, 76 insertions(+), 66 deletions(-) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 42ad7c2716..0702f2dfb7 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -8,7 +8,7 @@ from dace import SDFG, SDFGState from dace.sdfg import nodes -from dace.transformation.dataflow import MapFusion, MapExpansion +from dace.transformation.dataflow import MapFusionSerial, MapFusionParallel, MapExpansion def count_node(sdfg: SDFG, node_type): @@ -33,7 +33,7 @@ def apply_fusion( """ num_maps_before = count_node(sdfg, nodes.MapEntry) org_sdfg = copy.deepcopy(sdfg) - sdfg.apply_transformations_repeated(MapFusion, validate=True, validate_all=True) + sdfg.apply_transformations_repeated(MapFusionSerial, validate=True, validate_all=True) num_maps_after = count_node(sdfg, nodes.MapEntry) has_processed = False @@ -496,7 +496,7 @@ def test_interstate_fusion(): ref_C = A + 30 ref_D = A + 26 - assert sdfg.apply_transformations_repeated(MapFusion, validate=True, validate_all=True) == 1 + assert sdfg.apply_transformations_repeated(MapFusionSerial, validate=True, validate_all=True) == 1 assert sdfg.number_of_nodes() == 2 assert len([node for node in state1.data_nodes() if node.data == "B"]) == 1 @@ -506,6 +506,79 @@ def test_interstate_fusion(): assert np.allclose(D, ref_D) +def test_parallel_fusion_simple(): + N1, N2 = 10, 20 + + def _make_sdfg(): + sdfg = dace.SDFG("simple_parallel_map") + state = sdfg.add_state("state", is_start_block=True) + for name in ("A", "B", "out1", "out2"): + sdfg.add_array(name, shape=(N1, N2), transient=False, dtype=dace.float64) + sdfg.add_scalar("dmr", dtype=dace.float64, transient=False) + A, B, dmr, out1, out2 = (state.add_access(name) for name in ("A", "B", "dmr", "out1", "out2")) + + _, map1_entry, _ = state.add_mapped_tasklet( + "map_with_dynamic_range", + map_ranges={"__i0": f"0:{N1}", "__i1": f"0:{N2}"}, + inputs={"__in0": dace.Memlet("A[__i0, __i1]")}, + code="__out = __in0 + dynamic_range_value", + outputs={"__out": dace.Memlet("out1[__i0, __i1]")}, + input_nodes={"A": A}, + output_nodes={"out1": out1}, + external_edges=True, + ) + state.add_edge( + dmr, + None, + map1_entry, + "dynamic_range_value", + dace.Memlet("dmr[0]"), + ) + map1_entry.add_in_connector("dynamic_range_value") + + _, map2_entry, _ = state.add_mapped_tasklet( + "map_without_dynamic_range", + map_ranges={"__i2": f"0:{N1}", "__i3": f"0:{N2}"}, + inputs={ + "__in0": dace.Memlet("A[__i2, __i3]"), + "__in1": dace.Memlet("B[__i2, __i3]") + }, + code="__out = __in0 + __in1", + outputs={"__out": dace.Memlet("out2[__i2, __i3]")}, + input_nodes={"A": A, "B": B}, + output_nodes={"out2": out2}, + external_edges=True, + ) + sdfg.validate() + return sdfg, map1_entry, map2_entry + + for mode in range(2): + A = np.random.rand(N1, N2) + B = np.random.rand(N1, N2) + dmr = 3.1415 + out1 = np.zeros_like(A) + out2 = np.zeros_like(B) + res1 = A + dmr + res2 = A + B + + sdfg, map1_entry, map2_entry = _make_sdfg() + + if mode: + map1_entry, map2_entry = map2_entry, map1_entry + + MapFusionParallel.apply_to( + sdfg, + map_entry_1=map1_entry, + map_entry_2=map2_entry, + verify=True, + ) + assert count_node(sdfg, dace.sdfg.nodes.MapEntry) == 1 + + sdfg(A=A, B=B, dmr=dmr, out1=out1, out2=out2) + assert np.allclose(out1, res1) + assert np.allclose(out2, res2) + + def test_fuse_indirect_accesses(): @dace.program(auto_optimize=False) @@ -657,70 +730,7 @@ def test_offset_correction_empty(): apply_fusion(sdfg, removed_maps=0) -def test_different_access(): - - def exptected(A, B): - N, M = A.shape - return (A + 1) + B[1:(N+1), 2:(M+2)] - - def _make_sdfg(N: int, M: int) -> dace.SDFG: - sdfg = dace.SDFG("test_different_access") - names = ["A", "B", "__tmp", "__return"] - def_shape = (N, M) - sshape = {"B": (N+1, M+2), "__tmp": (N+1, M+1)} - for name in names: - sdfg.add_array( - name, - shape=sshape.get(name, def_shape), - dtype=dace.float64, - transient=False, - ) - sdfg.arrays["__tmp"].transient = True - - state = sdfg.add_state(is_start_block=True) - A, B, _tmp, _return = (state.add_access(name) for name in names) - - state.add_mapped_tasklet( - "comp1", - map_ranges={"__i0": f"0:{N}", "__i1": f"0:{M}"}, - inputs={"__in": dace.Memlet("A[__i0, __i1]")}, - code="__out = __in + 1.0", - outputs={"__out": dace.Memlet("__tmp[__i0 + 1, __i1 + 1]")}, - input_nodes={A}, - output_nodes={_tmp}, - external_edges=True, - ) - state.add_mapped_tasklet( - "comp2", - map_ranges={"__i0": f"0:{N}", "__i1": f"0:{M}"}, - inputs={ - "__in1": dace.Memlet("__tmp[__i0 + 1, __i1 + 1]"), - "__in2": dace.Memlet("B[__i0 + 1, __i1 + 2]"), - }, - code="__out = __in1 + __in2", - outputs={"__out": dace.Memlet("__return[__i0, __i1]")}, - input_nodes={_tmp, B}, - output_nodes={_return}, - external_edges=True, - ) - - sdfg.validate() - return sdfg - - N, M = 14, 17 - sdfg = _make_sdfg(N, M) - apply_fusion(sdfg, final_maps=1) - - A = np.array(np.random.rand(N, M), dtype=np.float64, copy=True) - B = np.array(np.random.rand(N + 1, M + 2), dtype=np.float64, copy=True) - - ref = exptected(A, B) - res = sdfg(A=A, B=B) - assert np.allclose(ref, res) - - if __name__ == '__main__': - test_different_access() test_indirect_accesses() test_fusion_shared() test_fusion_with_transient() From 3453c6c8f608ee53625be05326cfd03218d05ad9 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 1 Nov 2024 09:51:12 +0100 Subject: [PATCH 083/115] It seems that I have removed a test. --- tests/transformations/mapfusion_test.py | 143 +++++++++++------------- 1 file changed, 66 insertions(+), 77 deletions(-) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 0702f2dfb7..eb6df63183 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -8,7 +8,7 @@ from dace import SDFG, SDFGState from dace.sdfg import nodes -from dace.transformation.dataflow import MapFusionSerial, MapFusionParallel, MapExpansion +from dace.transformation.dataflow import MapFusion, MapExpansion def count_node(sdfg: SDFG, node_type): @@ -33,7 +33,7 @@ def apply_fusion( """ num_maps_before = count_node(sdfg, nodes.MapEntry) org_sdfg = copy.deepcopy(sdfg) - sdfg.apply_transformations_repeated(MapFusionSerial, validate=True, validate_all=True) + sdfg.apply_transformations_repeated(MapFusion, validate=True, validate_all=True) num_maps_after = count_node(sdfg, nodes.MapEntry) has_processed = False @@ -496,7 +496,7 @@ def test_interstate_fusion(): ref_C = A + 30 ref_D = A + 26 - assert sdfg.apply_transformations_repeated(MapFusionSerial, validate=True, validate_all=True) == 1 + assert sdfg.apply_transformations_repeated(MapFusion, validate=True, validate_all=True) == 1 assert sdfg.number_of_nodes() == 2 assert len([node for node in state1.data_nodes() if node.data == "B"]) == 1 @@ -506,79 +506,6 @@ def test_interstate_fusion(): assert np.allclose(D, ref_D) -def test_parallel_fusion_simple(): - N1, N2 = 10, 20 - - def _make_sdfg(): - sdfg = dace.SDFG("simple_parallel_map") - state = sdfg.add_state("state", is_start_block=True) - for name in ("A", "B", "out1", "out2"): - sdfg.add_array(name, shape=(N1, N2), transient=False, dtype=dace.float64) - sdfg.add_scalar("dmr", dtype=dace.float64, transient=False) - A, B, dmr, out1, out2 = (state.add_access(name) for name in ("A", "B", "dmr", "out1", "out2")) - - _, map1_entry, _ = state.add_mapped_tasklet( - "map_with_dynamic_range", - map_ranges={"__i0": f"0:{N1}", "__i1": f"0:{N2}"}, - inputs={"__in0": dace.Memlet("A[__i0, __i1]")}, - code="__out = __in0 + dynamic_range_value", - outputs={"__out": dace.Memlet("out1[__i0, __i1]")}, - input_nodes={"A": A}, - output_nodes={"out1": out1}, - external_edges=True, - ) - state.add_edge( - dmr, - None, - map1_entry, - "dynamic_range_value", - dace.Memlet("dmr[0]"), - ) - map1_entry.add_in_connector("dynamic_range_value") - - _, map2_entry, _ = state.add_mapped_tasklet( - "map_without_dynamic_range", - map_ranges={"__i2": f"0:{N1}", "__i3": f"0:{N2}"}, - inputs={ - "__in0": dace.Memlet("A[__i2, __i3]"), - "__in1": dace.Memlet("B[__i2, __i3]") - }, - code="__out = __in0 + __in1", - outputs={"__out": dace.Memlet("out2[__i2, __i3]")}, - input_nodes={"A": A, "B": B}, - output_nodes={"out2": out2}, - external_edges=True, - ) - sdfg.validate() - return sdfg, map1_entry, map2_entry - - for mode in range(2): - A = np.random.rand(N1, N2) - B = np.random.rand(N1, N2) - dmr = 3.1415 - out1 = np.zeros_like(A) - out2 = np.zeros_like(B) - res1 = A + dmr - res2 = A + B - - sdfg, map1_entry, map2_entry = _make_sdfg() - - if mode: - map1_entry, map2_entry = map2_entry, map1_entry - - MapFusionParallel.apply_to( - sdfg, - map_entry_1=map1_entry, - map_entry_2=map2_entry, - verify=True, - ) - assert count_node(sdfg, dace.sdfg.nodes.MapEntry) == 1 - - sdfg(A=A, B=B, dmr=dmr, out1=out1, out2=out2) - assert np.allclose(out1, res1) - assert np.allclose(out2, res2) - - def test_fuse_indirect_accesses(): @dace.program(auto_optimize=False) @@ -730,6 +657,68 @@ def test_offset_correction_empty(): apply_fusion(sdfg, removed_maps=0) +def test_different_offsets(): + + def exptected(A, B): + N, M = A.shape + return (A + 1) + B[1:(N+1), 2:(M+2)] + + def _make_sdfg(N: int, M: int) -> dace.SDFG: + sdfg = dace.SDFG("test_different_access") + names = ["A", "B", "__tmp", "__return"] + def_shape = (N, M) + sshape = {"B": (N+1, M+2), "__tmp": (N+1, M+1)} + for name in names: + sdfg.add_array( + name, + shape=sshape.get(name, def_shape), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["__tmp"].transient = True + + state = sdfg.add_state(is_start_block=True) + A, B, _tmp, _return = (state.add_access(name) for name in names) + + state.add_mapped_tasklet( + "comp1", + map_ranges={"__i0": f"0:{N}", "__i1": f"0:{M}"}, + inputs={"__in": dace.Memlet("A[__i0, __i1]")}, + code="__out = __in + 1.0", + outputs={"__out": dace.Memlet("__tmp[__i0 + 1, __i1 + 1]")}, + input_nodes={A}, + output_nodes={_tmp}, + external_edges=True, + ) + state.add_mapped_tasklet( + "comp2", + map_ranges={"__i0": f"0:{N}", "__i1": f"0:{M}"}, + inputs={ + "__in1": dace.Memlet("__tmp[__i0 + 1, __i1 + 1]"), + "__in2": dace.Memlet("B[__i0 + 1, __i1 + 2]"), + }, + code="__out = __in1 + __in2", + outputs={"__out": dace.Memlet("__return[__i0, __i1]")}, + input_nodes={_tmp, B}, + output_nodes={_return}, + external_edges=True, + ) + + sdfg.validate() + return sdfg + + N, M = 14, 17 + sdfg = _make_sdfg(N, M) + apply_fusion(sdfg, final_maps=1) + + A = np.array(np.random.rand(N, M), dtype=np.float64, copy=True) + B = np.array(np.random.rand(N + 1, M + 2), dtype=np.float64, copy=True) + + ref = exptected(A, B) + res = sdfg(A=A, B=B) + assert np.allclose(ref, res) + + if __name__ == '__main__': test_indirect_accesses() test_fusion_shared() @@ -744,10 +733,10 @@ def test_offset_correction_empty(): test_fusion_with_nested_sdfg_0() test_interstate_fusion() test_fusion_with_nested_sdfg_1() - test_parallel_fusion_simple() test_fuse_indirect_accesses() test_offset_correction_range_read() test_offset_correction_scalar_read() test_offset_correction_empty() + test_different_offsets() print("SUCCESS") From 90731afba5a9df326f2c9a74f3b18444be903540 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 1 Nov 2024 10:46:23 +0100 Subject: [PATCH 084/115] Realized that I can not use `SDFG.shared_transient()` for detection if data can be removed. This is because the function is much less strict. --- dace/transformation/dataflow/map_fusion.py | 83 +++++++++++++++++++--- 1 file changed, 74 insertions(+), 9 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index eff04bd380..1efe4314ef 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -1150,14 +1150,17 @@ def is_shared_data( data: nodes.AccessNode, sdfg: dace.SDFG, ) -> bool: - """Tests if `data` is interstate data, an can not be removed. + """Tests if `data` is shared data, an can not be removed. - Interstate data is used to transmit data between multiple state or by - extension within the state. Thus it must be classified as a shared output. - This function will go through the SDFG to and collect the names of all data - container that should be classified as shared. Note that this is an over - approximation as it does not take the location into account, i.e. "is no longer - used". + Interstate data is used to transmit data, this includes: + - The data is referred in multiple states. + - The data is referred to multiple times in the same state, either the state + has multiple access nodes for that data or an access node has an + out degree larger than one. + - The data is read inside interstate edges. + + This definition is stricter than the one employed by `SDFG.shared_transients()`, + as it also includes usage within a state. Args: transient: The transient that should be checked. @@ -1166,7 +1169,7 @@ def is_shared_data( Note: The function computes the this set once for every SDFG and then caches it. There is no mechanism to detect if the cache must be evicted. However, - as long as no additional data is added, there is no problem. + as long as no additional data is added to the SDFG, there is no problem. """ if sdfg not in self._shared_data: self._compute_shared_data_in(sdfg) @@ -1184,7 +1187,69 @@ def _compute_shared_data_in( Args: sdfg: The SDFG for which the set of shared data should be computed. """ - self._shared_data[sdfg] = set(sdfg.shared_transients()) + # Shared data of this SDFG. + shared_data: Set[str] = set() + + # All global data can not be removed, so it must always be shared. + for data_name, data_desc in sdfg.arrays.items(): + if not data_desc.transient: + shared_data.add(data_name) + elif isinstance(data_desc, dace.data.Scalar): + shared_data.add(data_name) + + # We go through all states and classify the nodes/data: + # - Data is referred to in different states. + # - The access node is a view (both have to survive). + # - Transient sink or source node. + # - The access node has output degree larger than 1 (input degrees larger + # than one, will always be partitioned as shared anyway). + prevously_seen_data: Set[str] = set() + for state in sdfg.nodes(): + for access_node in state.data_nodes(): + + if access_node.data in shared_data: + # The data was already classified to be shared data + pass + + elif access_node.data in prevously_seen_data: + # We have seen this data before, either in this state or in + # a previous one, but we did not classifies it as shared back then + shared_data.add(access_node.data) + + if state.in_degree(access_node) == 0: + # (Transient) sink nodes are used in other states, or simplify + # will get rid of them. + shared_data.add(access_node.data) + + elif state.out_degree(access_node) != 1: # state.out_degree() == 0 or state.out_degree() > 1 + # The access node is either a source node (it is shared in another + # state) or the node has a degree larger than one, so it is used + # in this state somewhere else. + shared_data.add(access_node.data) + + elif self.is_view(node=access_node, sdfg=sdfg): + # To ensure that the write to the view happens, both have to be shared. + viewed_data: str = self.track_view(view=access_node, state=state, sdfg=sdfg).data + shared_data.update([access_node.data, viewed_data]) + prevously_seen_data.update([access_node.data, viewed_data]) + + else: + # The node was not classified as shared data, so we record that + # we saw it. Note that a node that was immediately classified + # as shared node will never be added to this set, but a data + # that was found twice will be inside this list. + prevously_seen_data.add(access_node.data) + + # Now we collect all symbols that are read in interstate edges. + # Because, they might refer to data inside states and must be kept alive. + interstate_read_symbols: Set[str] = set() + for edge in sdfg.edges(): + interstate_read_symbols.update(edge.data.read_symbols()) + data_read_in_interstate_edges = interstate_read_symbols.intersection(prevously_seen_data) + + # Compute the final set of shared data and update the internal cache. + shared_data.update(data_read_in_interstate_edges) + self._shared_data[sdfg] = shared_data def _compute_multi_write_data( From f659cd827d80ad71f1802ca62ce0ccc698923780 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 15 Nov 2024 13:41:45 +0100 Subject: [PATCH 085/115] Removed teh specification of the intermediate. Before the intermediate scalar was a Register, but not the array. I noticed that some rediundand array removal transformation have problems with that. An alternative would be to make the array intermediate a register. --- dace/transformation/dataflow/map_fusion.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 1efe4314ef..98998b94ea 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -663,7 +663,6 @@ def handle_intermediate_set( new_inter_name, dtype=inter_desc.dtype, transient=True, - storage=dtypes.StorageType.Register, find_new_name=True, ) From 11a509e014ab11005da3022f50534b0a64646302 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 3 Dec 2024 14:54:18 +0100 Subject: [PATCH 086/115] Updated the strict dataflow mode. It is now made the default. Also added some tests to check if it is properly working. This mostly refactors the implementation. It should be the same as before, however, the only thing that I really removed, that could become a problem, is the check if it is used in the same state. This was always a hack, and the check should have been that it is not used in the same data flow. But I think it was redundant anyway. --- dace/transformation/dataflow/map_fusion.py | 96 +++++------------ tests/transformations/mapfusion_test.py | 115 ++++++++++++++++++++- 2 files changed, 142 insertions(+), 69 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 98998b94ea..1e82ba181c 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -20,14 +20,19 @@ class MapFusion(transformation.SingleStateTransformation): connections appropriately. Depending on the situation the transformation will either fully remove or make the intermediate a new output of the second map. - By default, the transformation does not use the strict data flow mode. However, - it might be useful in come cases to enable it. + By default `strict_dataflow` is enabled. In this mode, the transformation + will not fuse maps that could potentially lead to a data race, because the + resulting combined map reads and writes from the same underlying data. + If strict dataflow is disabled, then the transformation might fuse such maps. + However, it will ensure that the accesses are point wise, this means that + in each iteration the map only accesses the same location that it also writes + to. Note that this could still lead to data races, because the order in which + DaCe generates the reads and writes is indeterministic. Args: only_inner_maps: Only match Maps that are internal, i.e. inside another Map. only_toplevel_maps: Only consider Maps that are at the top. - strict_dataflow: If `True`, the transformation ensures a more - stricter version of the data flow. + strict_dataflow: Which dataflow mode should be used, see above. Notes: - This transformation modifies more nodes than it matches. @@ -302,17 +307,6 @@ def partition_first_outputs( # Set of intermediate nodes that we have already processed. processed_inter_nodes: Set[nodes.Node] = set() - # These are the data that is written to multiple times in _this_ state. - # If a data is written to multiple time in a state, it could be - # classified as shared. However, it might happen that the node has zero - # degree. This is not a problem as the maps also induced a before-after - # relationship. But some DaCe transformations do not catch this. - # Thus we will never modify such intermediate nodes and fail instead. - if self.strict_dataflow: - multi_write_data: Set[str] = self._compute_multi_write_data(state, sdfg) - else: - multi_write_data = set() - # Now scan all output edges of the first exit and classify them for out_edge in state.out_edges(map_exit_1): intermediate_node: nodes.Node = out_edge.dst @@ -356,16 +350,6 @@ def partition_first_outputs( if self.is_view(intermediate_node, sdfg): return None - # Checks if the intermediate node refers to data that is accessed by - # _other_ access nodes in _this_ state. If this is the case then never - # touch this intermediate node. - # TODO(phimuell): Technically it would be enough to turn the node into - # a shared output node, because this will still fulfil the dependencies. - # However, some DaCe transformation can not handle this properly, so we - # are _forced_ to reject this node. - if intermediate_node.data in multi_write_data: - return None - # Empty Memlets are only allowed if they are in `\mathbb{P}`, which # is also the only place they really make sense (for a map exit). # Thus if we now found an empty Memlet we reject it. @@ -1033,11 +1017,6 @@ def has_read_write_dependency( if not real_write_map_1.isdisjoint(real_write_map_2): return True - # If there is no overlap in what is (totally) read and written, there will be no conflict. - # This must come before the check of disjoint write. - if (real_read_map_1 | real_read_map_2).isdisjoint(real_write_map_1 | real_write_map_2): - return False - # These are the names (unresolved) and the access nodes of the data that is used # to transmit information between the maps. The partition function ensures that # these nodes are directly connected to the two maps. @@ -1062,11 +1041,29 @@ def has_read_write_dependency( if any(self.is_view(read_map_1[name], sdfg) for name in fused_inout_data_names): return True + # In strict data flow mode we require that the input and the output of + # the fused map is distinct. + # NOTE: The code below is able to handle cases were an input to map 1 + # is also used as output of map 2. In this case the function check + # if they are point wise, i.e. every iteration reads from the same + # location it later writes to. However, even then it might cause + # problems because in which order the reads and writes are done is + # indeterministic. But if this is handled through other means, then + # it allows powerful optimizations. + if self.strict_dataflow: + if len(fused_inout_data_names) != 0: + return True + # A data container can be used as input and output. But we do not allow that - # it is also used as intermediate or exchange data. This is an important check. + # it is also used as intermediate or exchange data. if not fused_inout_data_names.isdisjoint(exchange_names): return True + # If there is no intersection between the input and output data, then we can + # we have nothing to check. + if len(fused_inout_data_names) == 0: + return False + # Get the replacement dict for changing the map variables from the subsets of # the second map. repl_dict = self.find_parameter_remapping(map_entry_1.map, map_exit_2.map) @@ -1251,43 +1248,6 @@ def _compute_shared_data_in( self._shared_data[sdfg] = shared_data - def _compute_multi_write_data( - self, - state: SDFGState, - sdfg: SDFG, - ) -> Set[str]: - """Computes data inside a _single_ state, that is written multiple times. - - Essentially this function computes the set of data that does not follow - the single static assignment idiom. The function also resolves views. - If an access node, refers to a view, not only the view itself, but also - the data it refers to is added to the set. - - Args: - state: The state that should be examined. - sdfg: The SDFG object. - - Note: - This information is used by the partition function (in case strict data - flow mode is enabled), in strict data flow mode only. The current - implementation is rather simple as it only checks if a data is written - to multiple times in the same state. - """ - data_written_to: Set[str] = set() - multi_write_data: Set[str] = set() - - for access_node in state.data_nodes(): - if state.in_degree(access_node) == 0: - continue - if access_node.data in data_written_to: - multi_write_data.add(access_node.data) - elif self.is_view(access_node, sdfg): - # This is an over approximation. - multi_write_data.update([access_node.data, self.track_view(access_node, state, sdfg).data]) - data_written_to.add(access_node.data) - return multi_write_data - - def find_parameter_remapping(self, first_map: nodes.Map, second_map: nodes.Map) -> Union[Dict[str, str], None]: """Computes the parameter remapping for the parameters of the _second_ map. diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index eb6df63183..2e19db726d 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -1,10 +1,11 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from typing import Any, Union +from typing import Any, Union, Tuple, Optional import numpy as np import os import dace import copy +import uuid from dace import SDFG, SDFGState from dace.sdfg import nodes @@ -719,7 +720,119 @@ def _make_sdfg(N: int, M: int) -> dace.SDFG: assert np.allclose(ref, res) +def _make_strict_dataflow_sdfg_pointwise( + input_data: str = "A", + intermediate_data: str = "T", + output_data: Optional[str] = None, + input_read: str = "__i0", + output_write: Optional[str] = None, +) -> Tuple[dace.SDFG, dace.SDFGState]: + """ + Creates the SDFG for the strict data flow tests. + + The SDFG will read and write into `A`, but it is pointwise, thus the Maps can + be fused. Furthermore, this particular SDFG guarantees that no data race occurs. + """ + if output_data is None: + output_data = input_data + if output_write is None: + output_write = input_read + + sdfg = dace.SDFG(f"strict_dataflow_sdfg_pointwise_{str(uuid.uuid1()).replace('-', '_')}") + state = sdfg.add_state(is_start_block=True) + for name in {input_data, intermediate_data, output_data}: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + + if intermediate_data not in {input_data, output_data}: + sdfg.arrays[intermediate_data].transient = True + + input_node, intermediate_node, output_node = (state.add_access(name) for name in [input_data, intermediate_data, output_data]) + + state.add_mapped_tasklet( + "first_comp", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet(f"{input_data}[{input_read}]")}, + code="__out = __in1 + 2.0", + outputs={"__out": dace.Memlet(f"{intermediate_data}[__i0]")}, + input_nodes={input_node}, + output_nodes={intermediate_node}, + external_edges=True, + ) + state.add_mapped_tasklet( + "second_comp", + map_ranges={"__i1": "0:10"}, + inputs={"__in1": dace.Memlet(f"{intermediate_data}[__i1]")}, + code="__out = __in1 + 3.0", + outputs={"__out": dace.Memlet(f"{output_data}[{output_write}]")}, + input_nodes={intermediate_node}, + output_nodes={output_node}, + external_edges=True, + ) + sdfg.validate() + return sdfg, state + + +def test_fusion_strict_dataflow_pointwise(): + sdfg, state = _make_strict_dataflow_sdfg_pointwise(input_data="A") + + # Because `A` is used as input and output in strict data flow mode, + # the maps can not be fused. + count = sdfg.apply_transformations_repeated( + MapFusion(strict_dataflow=True), + validate=True, + validate_all=True, + ) + assert count == 0 + + # However, if strict dataflow is disabled, then it will be able to fuse. + count = sdfg.apply_transformations_repeated( + MapFusion(strict_dataflow=False), + validate=True, + validate_all=True, + ) + assert count == 1 + + +def test_fusion_strict_dataflow_not_pointwise(): + sdfg, state = _make_strict_dataflow_sdfg_pointwise( + input_data="A", + input_read="__i0", + output_write="9 - __i0", + ) + + # Because the dependency is not pointwise even disabling strict dataflow + # will not make it work. + count = sdfg.apply_transformations_repeated( + MapFusion(strict_dataflow=False), + validate=True, + validate_all=True, + ) + assert count == 0 + + +def test_fusion_dataflow_intermediate(): + sdfg, _ = _make_strict_dataflow_sdfg_pointwise( + input_data="A", + intermediate_data="O", + output_data="O", + ) + count = sdfg.apply_transformations_repeated( + MapFusion(strict_dataflow=True), + validate=True, + validate_all=True, + ) + assert count == 0 + + if __name__ == '__main__': + test_fusion_strict_dataflow_pointwise() + test_fusion_strict_dataflow_not_pointwise() + test_fusion_dataflow_intermediate() test_indirect_accesses() test_fusion_shared() test_fusion_with_transient() From 724d1f525bbb62ae775d368c1f99401f05105b73 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 3 Dec 2024 14:58:27 +0100 Subject: [PATCH 087/115] No longer explicitly specify strict dataflow mode. --- dace/transformation/auto/auto_optimize.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dace/transformation/auto/auto_optimize.py b/dace/transformation/auto/auto_optimize.py index c9c665566b..7166f9e364 100644 --- a/dace/transformation/auto/auto_optimize.py +++ b/dace/transformation/auto/auto_optimize.py @@ -58,10 +58,8 @@ def greedy_fuse(graph_or_subgraph: GraphViewType, # If we have an SDFG, recurse into graphs graph_or_subgraph.simplify(validate_all=validate_all) # MapFusion for trivial cases - # We have to use `strict_dataflow` because it is known that `CompositeFusion` - # has problems otherwise. graph_or_subgraph.apply_transformations_repeated( - MapFusion(strict_dataflow=True), + MapFusion, validate_all=validate_all, ) From 10ae1e2cade1320ca4d32cf5d28731ddeb16fa41 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 4 Dec 2024 07:03:18 +0100 Subject: [PATCH 088/115] Fixed a typo that was introduced in commit e1daf32fc8. --- dace/subsets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/subsets.py b/dace/subsets.py index 0fdc36c22e..9c79e7d7d1 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -80,7 +80,7 @@ def covers(self, other): # Subsets of different dimensionality can never cover each other. if self.dims() != other.dims(): return ValueError( - f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}" + f"A subset of dimensionality {self.dims()} cannot test covering a subset of dimensionality {other.dims()}" ) if not Config.get('optimizer', 'symbolic_positive'): @@ -106,7 +106,7 @@ def covers_precise(self, other): # Subsets of different dimensionality can never cover each other. if self.dims() != other.dims(): return ValueError( - f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}" + f"A subset of dimensionality {self.dims()} cannot test covering a subset of dimensionality {other.dims()}" ) # If self does not cover other with a bounding box union, return false. From d909379636bfb0bf19d022aafadeaccd4fbe00c0 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 4 Dec 2024 07:56:00 +0100 Subject: [PATCH 089/115] Forgot to make strict dataflow the default. --- dace/transformation/dataflow/map_fusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 1e82ba181c..8ecdb45cdf 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -68,7 +68,7 @@ class MapFusion(transformation.SingleStateTransformation): ) strict_dataflow = properties.Property( dtype=bool, - default=False, + default=True, desc="If `True` then the transformation will ensure a more stricter data flow.", ) # Maps SDFGs to the set of data that can not be removed, From cc7324bad1f0a14da556ccb855091c865354c049 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 4 Dec 2024 09:49:07 +0100 Subject: [PATCH 090/115] Added a new check to the map fusion in case of shared intermediates. It essentially check if the intermdeiate, that should be turned into a shared intermediate, i.e. a sink node to the map exit, is used in the data flow downstream. This is needed because some DaCe transformations do not correctkly check for the existence, it even seems that it is killed. --- dace/transformation/dataflow/map_fusion.py | 49 ++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 8ecdb45cdf..652e0dede4 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -464,6 +464,21 @@ def partition_first_outputs( # output of the second map. if self.is_shared_data(intermediate_node, sdfg): # The intermediate data is used somewhere else, either in this or another state. + # NOTE: If the intermediate is shared, then we will turn it into a + # sink node attached to the combined map exit. Technically this + # should be enough, even if the same data appears again in the + # dataflow down streams. However, some DaCe transformations, + # I am looking at you `auto_optimizer()` do not like that. Thus + # if the intermediate is used further down in the same datadflow, + # then we consider that the maps can not be fused. But we only + # do this in the strict data flow mode. + if self.strict_dataflow: + if self._is_data_accessed_downstream( + data=intermediate_node.data, + graph=state, + begin=intermediate_node, # is ignored itself. + ): + return None shared_outputs.add(out_edge) else: # The intermediate can be removed, as it is not used anywhere else. @@ -1407,6 +1422,40 @@ def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: return False + def _is_data_accessed_downstream( + self, + data: str, + graph: dace.SDFGState, + begin: nodes.Node, + ) -> bool: + """Tests if there is an AccessNode for `data` downstream of `begin`. + + Essentially, this function starts a DFS at `begin` and checks every + AccessNode that is reachable from it. If it finds such a node it will + check if it refers to `data` and if so, it will return `True`. + If no such node is found it will return `False`. + Note that the node `begin` will be ignored. + + Args: + data: The name of the data to look for. + graph: The graph to explore. + begin: The node to start exploration; The node itself is ignored. + """ + def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: + return (edge.dst for edge in graph.out_edges(node)) + + # Dataflow graph is acyclic, so we do not need to keep a list of + # what we have visited. + to_visit: List[nodes.Node] = list(next_nodes(begin)) + while len(to_visit) > 0: + node = to_visit.pop() + if isinstance(node, nodes.AccessNode) and node.data == data: + return True + to_visit.extend(next_nodes(node)) + + return False + + def get_access_set( self, scope_node: Union[nodes.MapEntry, nodes.MapExit], From 6429d91a45fbf43f352828eb8b6e45a6f7a2368d Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 4 Dec 2024 09:51:46 +0100 Subject: [PATCH 091/115] Chacnged the test. This test failed due to numerical instabilities. It passed once I changed the arguments, which to me does not make sense, as I think if `a` is close to `b` then `b` should also be close to `a`. So I changed the test to an absoluet check. --- tests/npbench/polybench/correlation_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/npbench/polybench/correlation_test.py b/tests/npbench/polybench/correlation_test.py index d743ba528d..0837e43fe9 100644 --- a/tests/npbench/polybench/correlation_test.py +++ b/tests/npbench/polybench/correlation_test.py @@ -70,7 +70,7 @@ def run_correlation(device_type: dace.dtypes.DeviceType): if device_type in {dace.dtypes.DeviceType.CPU, dace.dtypes.DeviceType.GPU}: # Parse the SDFG and apply autopot - sdfg = correlation_kernel.to_sdfg() + sdfg = correlation_kernel.to_sdfg(simplify=True) sdfg = auto_optimize(sdfg, device_type) last_value = os.environ.get('DACE_testing_serialization', '0') os.environ['DACE_testing_serialization'] = '0' @@ -83,7 +83,8 @@ def run_correlation(device_type: dace.dtypes.DeviceType): # Compute ground truth and validate result corr_ref = ground_truth(M, float_n_ref, data_ref) - assert np.allclose(corr, corr_ref) + diff = corr_ref - corr + assert np.abs(diff).max() <= 10e-10 return sdfg From 3d1cd9ec257805541a539714bbe539a43a36e412 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 4 Dec 2024 11:11:20 +0100 Subject: [PATCH 092/115] Realiced that there is no problem with a data race. The reason is that everything goes through the intermediates should not be a problem. --- dace/transformation/dataflow/map_fusion.py | 38 +++++++++------------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 652e0dede4..54ff145503 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -20,14 +20,12 @@ class MapFusion(transformation.SingleStateTransformation): connections appropriately. Depending on the situation the transformation will either fully remove or make the intermediate a new output of the second map. - By default `strict_dataflow` is enabled. In this mode, the transformation - will not fuse maps that could potentially lead to a data race, because the - resulting combined map reads and writes from the same underlying data. - If strict dataflow is disabled, then the transformation might fuse such maps. - However, it will ensure that the accesses are point wise, this means that - in each iteration the map only accesses the same location that it also writes - to. Note that this could still lead to data races, because the order in which - DaCe generates the reads and writes is indeterministic. + By default `strict_dataflow` is enabled. In this mode the transformation is + more conservative. The main difference is, that it will not adjust the + subsets of the intermediate, i.e. turning an array with shape `(1, 1, 1, 1)` + into a scalar. + Furthermore, shared intermediates, see `partition_first_outputs()` will only + be created if the data is not referred downstream in the dataflow. Args: only_inner_maps: Only match Maps that are internal, i.e. inside another Map. @@ -1050,27 +1048,21 @@ def has_read_write_dependency( # This is the names of the node that are used as input of the first map and # as output of the second map. We have to ensure that there is no data # dependency between these nodes. + # NOTE: This set is not required to be empty. It might look as this would + # create a data race, but it is save. The reason is because all data has + # to pass through the intermediate we create, this will separate the reads + # from the writes. fused_inout_data_names: Set[str] = set(read_map_1.keys()).intersection(write_map_2.keys()) # If a data container is used as input and output then it can not be a view (simplicity) if any(self.is_view(read_map_1[name], sdfg) for name in fused_inout_data_names): return True - # In strict data flow mode we require that the input and the output of - # the fused map is distinct. - # NOTE: The code below is able to handle cases were an input to map 1 - # is also used as output of map 2. In this case the function check - # if they are point wise, i.e. every iteration reads from the same - # location it later writes to. However, even then it might cause - # problems because in which order the reads and writes are done is - # indeterministic. But if this is handled through other means, then - # it allows powerful optimizations. - if self.strict_dataflow: - if len(fused_inout_data_names) != 0: - return True - - # A data container can be used as input and output. But we do not allow that - # it is also used as intermediate or exchange data. + # A data container can not be used as output (of the second as well as the + # combined map) and as intermediate. If we would allow that the map would + # have two output nodes one the original one and the second is the created + # node that is created because the intermediate is shared. + # TODO(phimuell): Handle this case. if not fused_inout_data_names.isdisjoint(exchange_names): return True From a412394d1883c032e4ecc5fc6aad0da0fffe3529 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 4 Dec 2024 11:13:08 +0100 Subject: [PATCH 093/115] Added more tests to the map fusion. --- tests/transformations/mapfusion_test.py | 80 ++++++++++++++++++++++--- 1 file changed, 71 insertions(+), 9 deletions(-) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 2e19db726d..849fdfda9f 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -780,15 +780,6 @@ def _make_strict_dataflow_sdfg_pointwise( def test_fusion_strict_dataflow_pointwise(): sdfg, state = _make_strict_dataflow_sdfg_pointwise(input_data="A") - # Because `A` is used as input and output in strict data flow mode, - # the maps can not be fused. - count = sdfg.apply_transformations_repeated( - MapFusion(strict_dataflow=True), - validate=True, - validate_all=True, - ) - assert count == 0 - # However, if strict dataflow is disabled, then it will be able to fuse. count = sdfg.apply_transformations_repeated( MapFusion(strict_dataflow=False), @@ -829,10 +820,81 @@ def test_fusion_dataflow_intermediate(): assert count == 0 +def test_fusion_dataflow_intermediate_2(): + # Because `A` is not also output transformation applies. + sdfg, state = _make_strict_dataflow_sdfg_pointwise( + input_data="A", + intermediate_data="A", + output_data="O", + ) + count = sdfg.apply_transformations_repeated( + MapFusion(strict_dataflow=True), + validate=True, + validate_all=True, + ) + assert count == 1 + map_exit = next(iter(node for node in state.nodes() if isinstance(node, nodes.MapExit))) + assert state.out_degree(map_exit) == 2 + assert {"A", "O"} == {edge.dst.data for edge in state.out_edges(map_exit) if isinstance(edge.dst, nodes.AccessNode)} + + +def test_fusion_dataflow_intermediate_downstream(): + # Because the intermediate `T` is used downstream again, + # the transformation can not apply. + sdfg, state = _make_strict_dataflow_sdfg_pointwise( + input_data="A", + intermediate_data="T", + output_data="output_1", + ) + sdfg.arrays["output_1"].transient = False + sdfg.arrays["T"].transient = True + output_1 = next(iter(dnode for dnode in state.sink_nodes())) + assert isinstance(output_1, nodes.AccessNode) and output_1.data == "output_1" + + # Make the real output node. + sdfg.arrays["O"] = sdfg.arrays["A"].clone() + state.add_mapped_tasklet( + "downstream_computation", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("output_1[__i0]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("T[__i0]")}, + input_nodes={output_1}, + external_edges=True, + ) + sdfg.validate() + + count = sdfg.apply_transformations_repeated( + MapFusion(strict_dataflow=True), + validate=True, + validate_all=True, + ) + assert count == 0 + + # However without strict dataflow, the merge is possible. + count = sdfg.apply_transformations_repeated( + MapFusion(strict_dataflow=False), + validate=True, + validate_all=True, + ) + assert count == 1 + assert state.in_degree(output_1) == 1 + assert state.out_degree(output_1) == 1 + assert all(isinstance(edge.src, nodes.MapExit) for edge in state.in_edges(output_1)) + assert all(isinstance(edge.dst, nodes.MapEntry) for edge in state.out_edges(output_1)) + + upper_map_exit = next(iter(edge.src for edge in state.in_edges(output_1))) + assert isinstance(upper_map_exit, nodes.MapExit) + assert state.out_degree(upper_map_exit) == 2 + assert {"T", "output_1"} == {edge.dst.data for edge in state.out_edges(upper_map_exit) if isinstance(edge.dst, nodes.AccessNode)} + + if __name__ == '__main__': test_fusion_strict_dataflow_pointwise() test_fusion_strict_dataflow_not_pointwise() test_fusion_dataflow_intermediate() + test_fusion_dataflow_intermediate_2() + test_fusion_dataflow_intermediate_downstream() test_indirect_accesses() test_fusion_shared() test_fusion_with_transient() From fa21bd3bc4798792452780c813e18fa22d13b670 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 5 Dec 2024 11:42:16 +0100 Subject: [PATCH 094/115] Added a new test for the map fusion. It mainly verifies that fusion stop at a certain invalid point. --- tests/transformations/mapfusion_test.py | 89 +++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 849fdfda9f..6a257cf13e 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -889,7 +889,96 @@ def test_fusion_dataflow_intermediate_downstream(): assert {"T", "output_1"} == {edge.dst.data for edge in state.out_edges(upper_map_exit) if isinstance(edge.dst, nodes.AccessNode)} +def test_fusion_non_strict_dataflow_implicit_dependency(): + """ + This test simulates if the fusion respect implicit dependencies, given by access nodes. + + This test simulates a situation that could arise if non strict dataflow is enabled. + The test ensures that the fusion does not continue fusing in this situation. + """ + sdfg = dace.SDFG("fusion_strict_dataflow_implicit_dependency_sdfg") + state = sdfg.add_state(is_start_block=True) + names = ["A", "B", "T1", "T2", "C"] + for name in names: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["T1"].transient = True + sdfg.arrays["T2"].transient = True + + me, mx = state.add_map( + "first_map", + ndrange={"__i0": "0:10"} + ) + tskl1 = state.add_tasklet( + "tskl1", + inputs={"__in1", "__in2"}, + code="__out = __in1 * __in2", + outputs={"__out"} + ) + tskl2 = state.add_tasklet( + "tskl2", + inputs={"__in1", "__in2"}, + code="__out = (__in1 + __in2) / 2", + outputs={"__out"} + ) + A, B, T1, T2 = (state.add_access(name) for name in names[:-1]) + + state.add_edge(A, None, me, "IN_A", dace.Memlet("A[0:10]")) + state.add_edge(B, None, me, "IN_B", dace.Memlet("B[0:10]")) + me.add_in_connector("IN_A") + me.add_in_connector("IN_B") + + state.add_edge(me, "OUT_A", tskl1, "__in1", dace.Memlet("A[__i0]")) + state.add_edge(me, "OUT_B", tskl1, "__in2", dace.Memlet("B[__i0]")) + state.add_edge(me, "OUT_A", tskl2, "__in1", dace.Memlet("A[__i0]")) + state.add_edge(me, "OUT_B", tskl2, "__in2", dace.Memlet("B[__i0]")) + me.add_out_connector("OUT_A") + me.add_out_connector("OUT_B") + + state.add_edge(tskl1, "__out", mx, "IN_T1", dace.Memlet("T1[__i0]")) + state.add_edge(tskl2, "__out", mx, "IN_T2", dace.Memlet("T2[__i0]")) + mx.add_in_connector("IN_T1") + mx.add_in_connector("IN_T2") + + state.add_edge(mx, "OUT_T1", T1, None, dace.Memlet("T1[0:10]")) + state.add_edge(mx, "OUT_T2", T2, None, dace.Memlet("T2[0:10]")) + mx.add_out_connector("OUT_T1") + mx.add_out_connector("OUT_T2") + + state.add_mapped_tasklet( + "second_map", + map_ranges={"__in0": "0:10"}, + inputs={"__in1": dace.Memlet("T1[__i0]")}, + code="if __in1 < 0.5:\n\t__out = 100.", + outputs={"__out": dace.Memlet("T2[__i0]", dynamic=True)}, + input_nodes={T1}, + external_edges=True, + ) + + state2 = sdfg.add_state_after(state) + state2.add_edge( + state2.add_access("T2"), + None, + state2.add_access("C"), + None, + dace.Memlet("T2[0:10] -> [0:10]"), + ) + sdfg.validate() + + count = sdfg.apply_transformations_repeated( + MapFusion(strict_dataflow=False), + validate=True, + validate_all=True, + ) + assert count == 0 + + if __name__ == '__main__': + test_fusion_non_strict_dataflow_implicit_dependency() test_fusion_strict_dataflow_pointwise() test_fusion_strict_dataflow_not_pointwise() test_fusion_dataflow_intermediate() From d8da3c6319d2122180e3ef88d159866b389c33a3 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 13 Dec 2024 09:44:14 +0100 Subject: [PATCH 095/115] WIP: Started with implementing Phil's suggestions. --- dace/transformation/auto/auto_optimize.py | 2 +- dace/transformation/dataflow/buffer_tiling.py | 6 +- dace/transformation/dataflow/map_fusion.py | 368 +++++++++++------- dace/transformation/dataflow/mapreduce.py | 4 +- tests/npbench/polybench/correlation_test.py | 2 +- tests/transformations/apply_to_test.py | 6 +- .../mapfusion_data_races_test.py | 1 - tests/transformations/mapfusion_test.py | 98 ++++- 8 files changed, 329 insertions(+), 158 deletions(-) diff --git a/dace/transformation/auto/auto_optimize.py b/dace/transformation/auto/auto_optimize.py index 7166f9e364..1fc11a076f 100644 --- a/dace/transformation/auto/auto_optimize.py +++ b/dace/transformation/auto/auto_optimize.py @@ -59,7 +59,7 @@ def greedy_fuse(graph_or_subgraph: GraphViewType, graph_or_subgraph.simplify(validate_all=validate_all) # MapFusion for trivial cases graph_or_subgraph.apply_transformations_repeated( - MapFusion, + MapFusion(strict_dataflow=True), validate_all=validate_all, ) diff --git a/dace/transformation/dataflow/buffer_tiling.py b/dace/transformation/dataflow/buffer_tiling.py index 6fac761175..af966d8a32 100644 --- a/dace/transformation/dataflow/buffer_tiling.py +++ b/dace/transformation/dataflow/buffer_tiling.py @@ -100,9 +100,9 @@ def apply(self, graph, sdfg): some_buffer = next(iter(buffers)) # some dummy to pass to MapFusion.apply_to() MapFusion.apply_to( sdfg, - map_exit_1=tile_map1_exit, - intermediate_access_node=some_buffer, - map_entry_2=tile_map2_entry, + first_map_exit=tile_map1_exit, + array=some_buffer, + second_map_entry=tile_map2_entry, verify=True, ) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 54ff145503..b080ae18d3 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -48,9 +48,9 @@ class MapFusion(transformation.SingleStateTransformation): """ # Pattern Nodes - map_exit_1 = transformation.transformation.PatternNode(nodes.MapExit) - intermediate_access_node = transformation.transformation.PatternNode(nodes.AccessNode) - map_entry_2 = transformation.transformation.PatternNode(nodes.MapEntry) + first_map_exit = transformation.transformation.PatternNode(nodes.MapExit) + array = transformation.transformation.PatternNode(nodes.AccessNode) + second_map_entry = transformation.transformation.PatternNode(nodes.MapEntry) # Settings @@ -84,11 +84,11 @@ def __init__( ) -> None: super().__init__(**kwargs) if only_toplevel_maps is not None: - self.only_toplevel_maps = bool(only_toplevel_maps) + self.only_toplevel_maps = only_toplevel_maps if only_inner_maps is not None: - self.only_inner_maps = bool(only_inner_maps) + self.only_inner_maps = only_inner_maps if strict_dataflow is not None: - self.strict_dataflow = bool(strict_dataflow) + self.strict_dataflow = strict_dataflow self._shared_data = {} @@ -102,7 +102,7 @@ def expressions(cls) -> Any: matched nodes, but more or less on anything that has an incoming connection from the first Map or an outgoing connection to the second Map entry. """ - return [dace.sdfg.utils.node_path_graph(cls.map_exit_1, cls.intermediate_access_node, cls.map_entry_2)] + return [dace.sdfg.utils.node_path_graph(cls.first_map_exit, cls.array, cls.second_map_entry)] def can_be_applied( @@ -119,18 +119,40 @@ def can_be_applied( - Tests if there are read write dependencies. - Tests if the decomposition exists. """ - map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit_1) - map_exit_1: nodes.MapExit = self.map_exit_1 - map_entry_2: nodes.MapEntry = self.map_entry_2 + first_map_entry: nodes.MapEntry = graph.entry_node(self.first_map_exit) + first_map_exit: nodes.MapExit = self.first_map_exit + second_map_entry: nodes.MapEntry = self.second_map_entry + + # Check the structural properties of the Maps. The function will return + # the `dict` that describes how the parameters must be renamed (for caching) + # or `None` if the maps can not be structurally fused. + param_repl = self.can_topologically_be_fused( + first_map_entry=first_map_entry, + second_map_entry=second_map_entry, + graph=graph, + sdfg=sdfg + ) + if param_repl is None: + return False - # This essentially test the structural properties of the two Maps. - if not self.can_topologically_be_fused(map_entry_1=map_entry_1, map_entry_2=map_entry_2, graph=graph, sdfg=sdfg): + # Tests if there are read write dependencies that are caused by the bodies + # of the Maps, such as referring to the same data. Note that this tests are + # different from the ones performed by `has_read_write_dependency()`, which + # only checks the data dependencies that go through the scope nodes. + if self.has_inner_read_write_dependency( + first_map_entry=first_map_entry, + second_map_entry=second_map_entry, + state=graph, + sdfg=sdfg, + ): return False - # Test for read-write conflicts + # Tests for read write conflicts of the two maps, this is only checking + # the data that goes through the scope nodes. `has_inner_read_write_dependency()` + # if used to check if there are internal dependencies. if self.has_read_write_dependency( - map_entry_1=map_entry_1, - map_entry_2=map_entry_2, + first_map_entry=first_map_entry, + second_map_entry=second_map_entry, state=graph, sdfg=sdfg, ): @@ -142,8 +164,9 @@ def can_be_applied( output_partition = self.partition_first_outputs( state=graph, sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, + first_map_exit=first_map_exit, + second_map_entry=second_map_entry, + param_repl=param_repl, ) if output_partition is None: return False @@ -170,27 +193,30 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non # once we start adding and removing nodes it seems that their ID changes. # Thus we have to save them here, this is a known behaviour in DaCe. assert isinstance(graph, dace.SDFGState) - assert isinstance(self.map_exit_1, nodes.MapExit) - assert isinstance(self.map_entry_2, nodes.MapEntry) + assert isinstance(self.first_map_exit, nodes.MapExit) + assert isinstance(self.second_map_entry, nodes.MapEntry) - map_exit_1: nodes.MapExit = self.map_exit_1 - map_entry_2: nodes.MapEntry = self.map_entry_2 - map_exit_2: nodes.MapExit = graph.exit_node(self.map_entry_2) - map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit_1) + first_map_exit: nodes.MapExit = self.first_map_exit + second_map_entry: nodes.MapEntry = self.second_map_entry + second_map_exit: nodes.MapExit = graph.exit_node(self.second_map_entry) + first_map_entry: nodes.MapEntry = graph.entry_node(self.first_map_exit) # Before we do anything we perform the renaming. self.rename_map_parameters( - first_map=map_exit_1.map, - second_map=map_entry_2.map, - second_map_entry=map_entry_2, + first_map=first_map_exit.map, + second_map=second_map_entry.map, + second_map_entry=second_map_entry, state=graph, ) + # Now compute the partition. Because we have already renamed the parameters + # of the second Map, there is no need to perform any renaming. output_partition = self.partition_first_outputs( state=graph, sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, + first_map_exit=first_map_exit, + second_map_entry=second_map_entry, + param_repl=dict(), ) assert output_partition is not None # Make MyPy happy. pure_outputs, exclusive_outputs, shared_outputs = output_partition @@ -200,9 +226,9 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non intermediate_outputs=exclusive_outputs, state=graph, sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - map_exit_2=map_exit_2, + first_map_exit=first_map_exit, + second_map_entry=second_map_entry, + second_map_exit=second_map_exit, is_exclusive_set=True, ) if len(shared_outputs) != 0: @@ -210,16 +236,16 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non intermediate_outputs=shared_outputs, state=graph, sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - map_exit_2=map_exit_2, + first_map_exit=first_map_exit, + second_map_entry=second_map_entry, + second_map_exit=second_map_exit, is_exclusive_set=False, ) - assert pure_outputs == set(graph.out_edges(map_exit_1)) + assert pure_outputs == set(graph.out_edges(first_map_exit)) if len(pure_outputs) != 0: self.relocate_nodes( - from_node=map_exit_1, - to_node=map_exit_2, + from_node=first_map_exit, + to_node=second_map_exit, state=graph, sdfg=sdfg, ) @@ -228,26 +254,27 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non # to the first map, now we must move the output of the first map # to the second one, as this one is used. self.relocate_nodes( - from_node=map_entry_2, - to_node=map_entry_1, + from_node=second_map_entry, + to_node=first_map_entry, state=graph, sdfg=sdfg, ) - for node_to_remove in [map_exit_1, map_entry_2]: + for node_to_remove in [first_map_exit, second_map_entry]: assert graph.degree(node_to_remove) == 0 graph.remove_node(node_to_remove) # Now turn the second output node into the output node of the first Map. - map_exit_2.map = map_entry_1.map + second_map_exit.map = first_map_entry.map def partition_first_outputs( self, state: SDFGState, sdfg: SDFG, - map_exit_1: nodes.MapExit, - map_entry_2: nodes.MapEntry, + first_map_exit: nodes.MapExit, + second_map_entry: nodes.MapEntry, + param_repl: Dict[str, str], ) -> Union[ Tuple[ Set[graph.MultiConnectorEdge[dace.Memlet]], @@ -256,7 +283,7 @@ def partition_first_outputs( ], None, ]: - """Partition the output edges of `map_exit_1` for serial map fusion. + """Partition the output edges of `first_map_exit` for serial map fusion. The output edges of the first map are partitioned into three distinct sets, defined as follows: @@ -277,36 +304,27 @@ def partition_first_outputs( output can be added to either intermediate set and might fail to compute the partition, even if it would exist. - Returns: - If such a decomposition exists the function will return the three sets - mentioned above in the same order. - In case the decomposition does not exist, i.e. the maps can not be fused - the function returns `None`. + :return: If such a decomposition exists the function will return the three sets + mentioned above in the same order. In case the decomposition does not exist, + i.e. the maps can not be fused the function returns `None`. - Args: - state: The in which the two maps are located. - sdfg: The full SDFG in whcih we operate. - map_exit_1: The exit node of the first map. - map_entry_2: The entry node of the second map. + :param state: The in which the two maps are located. + :param sdfg: The full SDFG in whcih we operate. + :param first_map_exit: The exit node of the first map. + :param second_map_entry: The entry node of the second map. + :param param_repl: Use this map to rename the parameter of the second Map, such + that they match the one of the first map. """ # The three outputs set. pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() exclusive_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() shared_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() - # Compute the renaming that for translating the parameter of the _second_ - # map to the ones used by the first map. - repl_dict: Dict[str, str] = self.find_parameter_remapping( - first_map=map_exit_1.map, - second_map=map_entry_2.map, - ) - assert repl_dict is not None - # Set of intermediate nodes that we have already processed. processed_inter_nodes: Set[nodes.Node] = set() # Now scan all output edges of the first exit and classify them - for out_edge in state.out_edges(map_exit_1): + for out_edge in state.out_edges(first_map_exit): intermediate_node: nodes.Node = out_edge.dst # We already processed the node, this should indicate that we should @@ -329,7 +347,7 @@ def partition_first_outputs( if not self.is_node_reachable_from( graph=state, begin=intermediate_node, - end=map_entry_2, + end=second_map_entry, ): pure_outputs.add(out_edge) continue @@ -361,7 +379,7 @@ def partition_first_outputs( # To handle this we need to associate a consumer edge (the outgoing edges # of the second map) with exactly one producer. producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list( - state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:])) + state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:])) if len(producer_edges) > 1: return None @@ -410,8 +428,8 @@ def partition_first_outputs( # If the second map entry is not immediately reachable from the intermediate # node, then ensure that there is not path that goes to it. - if intermediate_node_out_edge.dst is not map_entry_2: - if self.is_node_reachable_from(graph=state, begin=intermediate_node_out_edge.dst, end=map_entry_2): + if intermediate_node_out_edge.dst is not second_map_entry: + if self.is_node_reachable_from(graph=state, begin=intermediate_node_out_edge.dst, end=second_map_entry): return None continue @@ -431,7 +449,7 @@ def partition_first_outputs( # We do not check them, but collect them and inspect them. # NOTE: The subset still uses the old iteration variables. for inner_consumer_edge in state.out_edges_by_connector( - map_entry_2, "OUT_" + intermediate_node_out_edge.dst_conn[3:]): + second_map_entry, "OUT_" + intermediate_node_out_edge.dst_conn[3:]): if inner_consumer_edge.data.src_subset is None: return None if inner_consumer_edge.data.dynamic: @@ -442,10 +460,10 @@ def partition_first_outputs( assert len(consumer_subsets) != 0 # The consumer still uses the original symbols of the second map, so we must rename them. - if repl_dict: + if param_repl: consumer_subsets = copy.deepcopy(consumer_subsets) for consumer_subset in consumer_subsets: - symbolic.safe_replace(mapping=repl_dict, replace_callback=consumer_subset.replace) + symbolic.safe_replace(mapping=param_repl, replace_callback=consumer_subset.replace) # Now we are checking if a single iteration of the first (top) map # can satisfy all data requirements of the second (bottom) map. @@ -582,9 +600,9 @@ def handle_intermediate_set( intermediate_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]], state: SDFGState, sdfg: SDFG, - map_exit_1: nodes.MapExit, - map_entry_2: nodes.MapEntry, - map_exit_2: nodes.MapExit, + first_map_exit: nodes.MapExit, + second_map_entry: nodes.MapEntry, + second_map_exit: nodes.MapExit, is_exclusive_set: bool, ) -> None: """This function handles the intermediate sets. @@ -599,9 +617,9 @@ def handle_intermediate_set( intermediate_outputs: The set of outputs, that should be processed. state: The state in which the map is processed. sdfg: The SDFG that should be optimized. - map_exit_1: The exit of the first/top map. - map_entry_2: The entry of the second map. - map_exit_2: The exit of the second map. + first_map_exit: The exit of the first/top map. + second_map_entry: The entry of the second map. + second_map_exit: The exit of the second map. is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. Notes: @@ -609,7 +627,7 @@ def handle_intermediate_set( after this function has run the state is (most likely) invalid. """ - map_params = map_exit_1.map.params.copy() + map_params = first_map_exit.map.params.copy() # Now we will iterate over all intermediate edges and process them. # If not stated otherwise the comments assume that we run in exclusive mode. @@ -623,7 +641,7 @@ def handle_intermediate_set( # Now we will determine the shape of the new intermediate. This size of # this temporary is given by the Memlet that goes into the first map exit. - pre_exit_edges = list(state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:])) + pre_exit_edges = list(state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:])) if len(pre_exit_edges) != 1: raise NotImplementedError() pre_exit_edge = pre_exit_edges[0] @@ -732,7 +750,7 @@ def handle_intermediate_set( # NOTE: Assumes that map (if connected is the direct neighbour). conn_names: Set[str] = set() for inter_node_out_edge in state.out_edges(inter_node): - if inter_node_out_edge.dst == map_entry_2: + if inter_node_out_edge.dst == second_map_entry: assert inter_node_out_edge.dst_conn.startswith("IN_") conn_names.add(inter_node_out_edge.dst_conn) else: @@ -747,7 +765,7 @@ def handle_intermediate_set( for in_conn_name in conn_names: out_conn_name = "OUT_" + in_conn_name[3:] - for inner_edge in state.out_edges_by_connector(map_entry_2, out_conn_name): + for inner_edge in state.out_edges_by_connector(second_map_entry, out_conn_name): assert inner_edge.data.data == inter_name # DIRECTION!! # As for the producer side, we now read from a smaller array, @@ -798,11 +816,11 @@ def handle_intermediate_set( # The edge that leaves the second map entry was already deleted. We now delete # the edges that connected the intermediate node with the second map entry. - for edge in list(state.in_edges_by_connector(map_entry_2, in_conn_name)): + for edge in list(state.in_edges_by_connector(second_map_entry, in_conn_name)): assert edge.src == inter_node state.remove_edge(edge) - map_entry_2.remove_in_connector(in_conn_name) - map_entry_2.remove_out_connector(out_conn_name) + second_map_entry.remove_in_connector(in_conn_name) + second_map_entry.remove_out_connector(out_conn_name) if is_exclusive_set: # In exclusive mode the old intermediate node is no longer needed. @@ -812,15 +830,15 @@ def handle_intermediate_set( state.remove_node(inter_node) state.remove_edge(pre_exit_edge) - map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) - map_exit_1.remove_out_connector(out_edge.src_conn) + first_map_exit.remove_in_connector(pre_exit_edge.dst_conn) + first_map_exit.remove_out_connector(out_edge.src_conn) del sdfg.arrays[inter_name] else: # This is the shared mode, so we have to recreate the intermediate # node, but this time it is at the exit of the second map. state.remove_edge(pre_exit_edge) - map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) + first_map_exit.remove_in_connector(pre_exit_edge.dst_conn) # This is the Memlet that goes from the map internal intermediate # temporary node to the Map output. This will essentially restore @@ -830,25 +848,25 @@ def handle_intermediate_set( assert pre_exit_edge.data.data == inter_name final_pre_exit_memlet.other_subset = subsets.Range.from_array(new_inter_desc) - new_pre_exit_conn = map_exit_2.next_connector() + new_pre_exit_conn = second_map_exit.next_connector() state.add_edge( new_inter_node, None, - map_exit_2, + second_map_exit, "IN_" + new_pre_exit_conn, final_pre_exit_memlet, ) state.add_edge( - map_exit_2, + second_map_exit, "OUT_" + new_pre_exit_conn, inter_node, out_edge.dst_conn, copy.deepcopy(out_edge.data), ) - map_exit_2.add_in_connector("IN_" + new_pre_exit_conn) - map_exit_2.add_out_connector("OUT_" + new_pre_exit_conn) + second_map_exit.add_in_connector("IN_" + new_pre_exit_conn) + second_map_exit.add_out_connector("OUT_" + new_pre_exit_conn) - map_exit_1.remove_out_connector(out_edge.src_conn) + first_map_exit.remove_out_connector(out_edge.src_conn) state.remove_edge(out_edge) @@ -919,12 +937,12 @@ def compute_offset_subset( def can_topologically_be_fused( self, - map_entry_1: nodes.MapEntry, - map_entry_2: nodes.MapEntry, + first_map_entry: nodes.MapEntry, + second_map_entry: nodes.MapEntry, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG, permissive: bool = False, - ) -> bool: + ) -> Optional[Dict[str, str]]: """Performs basic checks if the maps can be fused. This function only checks constrains that are common between serial and @@ -933,42 +951,108 @@ def can_topologically_be_fused( - The scheduling of the maps. - The map parameters. - Args: - map_entry_1: The entry of the first (in serial case the top) map. - map_exit_2: The entry of the second (in serial case the bottom) map. - graph: The SDFGState in which the maps are located. - sdfg: The SDFG itself. - permissive: Currently unused. + :return: If the maps can not be topologically fused the function returns `None`. + If they can be fused the function returns `dict` that describes parameter + replacement, see `find_parameter_remapping()` for more. + + :param first_map_entry: The entry of the first (in serial case the top) map. + :param second_map_exit: The entry of the second (in serial case the bottom) map. + :param graph: The SDFGState in which the maps are located. + :param sdfg: The SDFG itself. + :param permissive: Currently unused. """ if self.only_inner_maps and self.only_toplevel_maps: - raise ValueError("You specified both `only_inner_maps` and `only_toplevel_maps`.") + raise ValueError("Only one of `only_inner_maps` and `only_toplevel_maps` is allowed per MapFusion instance.") # Ensure that both have the same schedule - if map_entry_1.map.schedule != map_entry_2.map.schedule: - return False + if first_map_entry.map.schedule != second_map_entry.map.schedule: + return None # Fusing is only possible if the two entries are in the same scope. scope = graph.scope_dict() - if scope[map_entry_1] != scope[map_entry_2]: - return False + if scope[first_map_entry] != scope[second_map_entry]: + return None elif self.only_inner_maps: - if scope[map_entry_1] is None: - return False + if scope[first_map_entry] is None: + return None elif self.only_toplevel_maps: - if scope[map_entry_1] is not None: - return False + if scope[first_map_entry] is not None: + return None # We will now check if there exists a remapping that of the map parameter - if self.find_parameter_remapping(first_map=map_entry_1.map, second_map=map_entry_2.map) is None: - return False + param_repl = self.find_parameter_remapping(first_map=first_map_entry.map, second_map=second_map_entry.map) + if param_repl is None: + return None + return None - return True + + def has_inner_read_write_dependency( + self, + first_map_entry: nodes.MapEntry, + second_map_entry: nodes.MapEntry, + state: SDFGState, + sdfg: SDFG, + ) -> bool: + """This function tests if there are dependency inside the Maps. + + The function will scan and anaysize the body of the two Maps and look for + inconsistencies. To detect them the function will scan the body of the maps + and examine the all AccessNodes and apply the following rules: + - If an AccessNode refers to a View, it is ignored. Because the source is + either on the outside, in which case `has_read_write_dependency()` + takes care of it, or the data source is inside the Map body itself. + - An inconsistency is detected, if in each bodies there exists an AccessNode + that refer to the same data. + - An inconsistency is detected, if there exists an AccessNode that refers + to non transient data. This is an implementation detail and could be + lifted. + + Note that some of the restrictions of this function could be relaxed by + performing more analysis. + + :return: The function returns `True` if an inconsistency has been found. + + :param first_map_entry: The entry node of the first map. + :param second_map_entry: The entry node of the second map. + :param state: The state on which we operate. + :param sdfg: The SDFG on which we operate. + """ + first_map_body = state.scope_subgraph(first_map_entry, False, False) + second_map_body = state.scope_subgraph(second_map_entry, False, False) + + # Find the data that is internally referenced. Because of the first rule above, + # we filter all views above. + first_map_body_data, second_map_body_data = [ + { + dnode.data + for dnode in map_body.nodes() + if isinstance(dnode, nodes.AccessNode) and not self.is_view(dnode, sdfg) + } + for map_body in [first_map_body, second_map_body] + ] + + # If there is data that is referenced in both, then we consider this as an error + # this is the second rule above. + if not first_map_body_data.isdisjoint(second_map_body_data): + return True + + # We consider it as a problem if any map refers to non-transient data. + # This is an implementation detail and could be dropped if we do further + # analysis. + if any( + not sdfg.arrays[data].transient + for data in first_map_body_data.union(second_map_body_data) + ): + return True + + return False def has_read_write_dependency( self, - map_entry_1: nodes.MapEntry, - map_entry_2: nodes.MapEntry, + first_map_entry: nodes.MapEntry, + second_map_entry: nodes.MapEntry, + param_repl: Dict[str, str], state: SDFGState, sdfg: SDFG, ) -> bool: @@ -983,24 +1067,22 @@ def has_read_write_dependency( at the same time. However, the function will not check for read write conflicts in this set, this is done in the partition function. - Returns: - `True` if there is a conflict between the maps that can not be handled. - If there is no conflict or if the conflict can be handled `False` - is returned. + :return: `True` if there is a conflict between the maps that can not be handled. + If there is no conflict or if the conflict can be handled `False` is returned. - Args: - map_entry_1: The entry node of the first map. - map_entry_2: The entry node of the second map. - state: The state on which we operate. - sdfg: The SDFG on which we operate. + :param first_map_entry: The entry node of the first map. + :param second_map_entry: The entry node of the second map. + :param param_repl: Dict that describes how to rename the parameters of the second Map. + :param state: The state on which we operate. + :param sdfg: The SDFG on which we operate. """ - map_exit_1: nodes.MapExit = state.exit_node(map_entry_1) - map_exit_2: nodes.MapExit = state.exit_node(map_entry_2) + first_map_exit: nodes.MapExit = state.exit_node(first_map_entry) + second_map_exit: nodes.MapExit = state.exit_node(second_map_entry) # Get the read and write sets of the different maps, note that Views # are not resolved yet. access_sets: List[Dict[str, nodes.AccessNode]] = [] - for scope_node in [map_entry_1, map_exit_1, map_entry_2, map_exit_2]: + for scope_node in [first_map_entry, first_map_exit, second_map_entry, second_map_exit]: access_set: Set[nodes.AccessNode] = self.get_access_set(scope_node, state) access_sets.append({node.data: node for node in access_set}) # If two different access nodes of the same scoping node refers to the @@ -1071,10 +1153,6 @@ def has_read_write_dependency( if len(fused_inout_data_names) == 0: return False - # Get the replacement dict for changing the map variables from the subsets of - # the second map. - repl_dict = self.find_parameter_remapping(map_entry_1.map, map_exit_2.map) - # Now we inspect if there is a read write dependency, between data that is # used as input and output of the fused map. There is no problem is they # are pointwise, i.e. in each iteration the same locations are accessed. @@ -1085,20 +1163,20 @@ def has_read_write_dependency( all_subsets.extend( self.find_subsets( node=read_map_1[inout_data_name], - scope_node=map_entry_1, + scope_node=first_map_entry, state=state, sdfg=sdfg, - repl_dict=None, + param_repl=None, )) # While the subsets defining writing are given by the second map's exit # node, there we also have to apply renaming. all_subsets.extend( self.find_subsets( node=write_map_2[inout_data_name], - scope_node=map_exit_2, + scope_node=second_map_exit, state=state, sdfg=sdfg, - repl_dict=repl_dict, + param_repl=param_repl, )) # Now we can test if these subsets are point wise if not self.test_if_subsets_are_point_wise(all_subsets): @@ -1255,7 +1333,7 @@ def _compute_shared_data_in( self._shared_data[sdfg] = shared_data - def find_parameter_remapping(self, first_map: nodes.Map, second_map: nodes.Map) -> Union[Dict[str, str], None]: + def find_parameter_remapping(self, first_map: nodes.Map, second_map: nodes.Map) -> Optional[Dict[str, str]]: """Computes the parameter remapping for the parameters of the _second_ map. The returned `dict` maps the parameters of the second map (keys) to parameter @@ -1486,7 +1564,7 @@ def find_subsets( scope_node: Union[nodes.MapExit, nodes.MapEntry], state: SDFGState, sdfg: SDFG, - repl_dict: Optional[Dict[str, str]], + param_repl: Optional[Dict[str, str]], ) -> List[subsets.Subset]: """Finds all subsets that access `node` within `scope_node`. @@ -1494,13 +1572,13 @@ def find_subsets( Instead it will locate the edges which is immediately inside the map scope. - Args: - node: The access node that should be examined. - scope_node: We are only interested in data that flows through this node. - state: The state in which we operate. - sdfg: The SDFG object. + :param node: The access node that should be examined. + :param scope_node: We are only interested in data that flows through this node. + :param state: The state in which we operate. + :param sdfg: The SDFG object. + :param param_repl: `dict` that describes the parameter renaming that should be + performed. Can be `None` to skip the processing. """ - # Is the node used for reading or for writing. # This influences how we have to proceed. if isinstance(scope_node, nodes.MapEntry): @@ -1519,10 +1597,10 @@ def find_subsets( assert not any(subset is None for subset in found_subsets) found_subsets = copy.deepcopy(found_subsets) - if repl_dict: + if param_repl: for subset in found_subsets: # Replace happens in place - symbolic.safe_replace(repl_dict, subset.replace) + symbolic.safe_replace(param_repl, subset.replace) return found_subsets diff --git a/dace/transformation/dataflow/mapreduce.py b/dace/transformation/dataflow/mapreduce.py index de76eee0ba..0eef39c3cb 100644 --- a/dace/transformation/dataflow/mapreduce.py +++ b/dace/transformation/dataflow/mapreduce.py @@ -217,7 +217,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): map_fusion = MapFusion() map_fusion.setup_match(sdfg, self.cfg_id, self.state_id, { - MapFusion.map_exit_1: graph.node_id(self.tmap_exit), - MapFusion.map_entry_2: graph.node_id(map_entry), + MapFusion.first_map_exit: graph.node_id(self.tmap_exit), + MapFusion.second_map_entry: graph.node_id(map_entry), }, 0) map_fusion.apply(graph, sdfg) diff --git a/tests/npbench/polybench/correlation_test.py b/tests/npbench/polybench/correlation_test.py index 0837e43fe9..a5532cf829 100644 --- a/tests/npbench/polybench/correlation_test.py +++ b/tests/npbench/polybench/correlation_test.py @@ -70,7 +70,7 @@ def run_correlation(device_type: dace.dtypes.DeviceType): if device_type in {dace.dtypes.DeviceType.CPU, dace.dtypes.DeviceType.GPU}: # Parse the SDFG and apply autopot - sdfg = correlation_kernel.to_sdfg(simplify=True) + sdfg = correlation_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) last_value = os.environ.get('DACE_testing_serialization', '0') os.environ['DACE_testing_serialization'] = '0' diff --git a/tests/transformations/apply_to_test.py b/tests/transformations/apply_to_test.py index c410c7a552..30ccb53f28 100644 --- a/tests/transformations/apply_to_test.py +++ b/tests/transformations/apply_to_test.py @@ -36,9 +36,9 @@ def test_applyto_enumerate(): pattern = sdutil.node_path_graph(dace.nodes.MapExit, dace.nodes.AccessNode, dace.nodes.MapEntry) for subgraph in enumerate_matches(sdfg, pattern): MapFusion.apply_to(sdfg, - map_exit_1=subgraph.source_nodes()[0], - intermediate_access_node=next(n for n in subgraph.nodes() if isinstance(n, dace.nodes.AccessNode)), - map_entry_2=subgraph.sink_nodes()[0]) + first_map_exit=subgraph.source_nodes()[0], + array=next(n for n in subgraph.nodes() if isinstance(n, dace.nodes.AccessNode)), + second_map_entry=subgraph.sink_nodes()[0]) def test_applyto_pattern(): diff --git a/tests/transformations/mapfusion_data_races_test.py b/tests/transformations/mapfusion_data_races_test.py index 0466a32551..ff87fd61ec 100644 --- a/tests/transformations/mapfusion_data_races_test.py +++ b/tests/transformations/mapfusion_data_races_test.py @@ -101,4 +101,3 @@ def test_rw_data_race_4_mf(): test_rw_data_race_3_sgf() test_rw_data_race_3_mf() test_rw_data_race_4_mf() - print("SUCCESS") diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 6a257cf13e..d47a89c7fb 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -428,6 +428,7 @@ def fusion_with_nested_sdfg_0(A: dace.int32[10], B: dace.int32[10], C: dace.int3 A[i] = tmp[i] * 2 sdfg = fusion_with_nested_sdfg_0.to_sdfg(simplify=True) + sdfg.view() # Because the transformation refuses to fuse dynamic edges. # We have to eliminate them. @@ -975,6 +976,101 @@ def test_fusion_non_strict_dataflow_implicit_dependency(): validate_all=True, ) assert count == 0 + + +def test_inner_map_dependency(): + sdfg = dace.SDFG("inner_map_dependency_sdfg") + state = sdfg.add_state(is_start_block=True) + + name_arrays = ["A", "T", "C"] + for aname in name_arrays: + sdfg.add_array( + name, + shape=(10,), + transient=False, + ) + sdfg.arrays["T"].transient = True + sdfg.add_scalar( + "s", + dtype=dace.float64, + transient=True, + ) + A, T, C = (state.add_access(name) for name in name_arrays) + s1, s2 = (state.add_access("s") for _ in range(2)) + + me1, mx1 = state.add_map( + "map_1", + ndrange={"__i0": "0:10"}, + ) + tsklt1 = state.add_tasklet( + "tskl1", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 + 1.0", + ) + + state.add_edge( + A, + None, + me1, + "IN_A", + dace.Memlet("A[0:10]") + ) + state.add_edge( + me1, + "OUT_A", + s1, + None, + dace.Memlet("A[__i0] -> [0]") + ) + me1.add_in_connector("IN_A") + me1.add_out_connector("IUT_A") + state.add_edge( + s1, + None, + tskl1, + dace.Memlet("s[0]") + ) + state.add_edge( + tsklt1, + "__out", + mx1, + "IN_T", + dace.Memlet("T[__i0]") + ) + state.add_edge( + mx1, + "OUT_T", + T, + None, + dace.Memlet("T[0:10]") + ) + mx1.add_in_connector("IN_T") + mx1.add_out_connector("OUT_T") + + + + + + + + + + + + + + + + + + + + + + + + if __name__ == '__main__': @@ -1002,5 +1098,3 @@ def test_fusion_non_strict_dataflow_implicit_dependency(): test_offset_correction_scalar_read() test_offset_correction_empty() test_different_offsets() - print("SUCCESS") - From e2285f0fef4ebc0f84498d5206ff4f9afe293dc2 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 16 Dec 2024 08:33:31 +0100 Subject: [PATCH 096/115] Made some modification, time to save. --- dace/transformation/dataflow/map_fusion.py | 147 ++++++++++++++------- 1 file changed, 96 insertions(+), 51 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index b080ae18d3..8334d07814 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -119,6 +119,11 @@ def can_be_applied( - Tests if there are read write dependencies. - Tests if the decomposition exists. """ + # To ensures that the `{src,dst}_subset` are properly set, run initialization. + # See [issue 1708](https://github.com/spcl/dace/issues/1703) + for edge in graph.edges(): + edge.data.try_initialize(sdfg, graph, edge) + first_map_entry: nodes.MapEntry = graph.entry_node(self.first_map_exit) first_map_exit: nodes.MapExit = self.first_map_exit second_map_entry: nodes.MapEntry = self.second_map_entry @@ -153,6 +158,7 @@ def can_be_applied( if self.has_read_write_dependency( first_map_entry=first_map_entry, second_map_entry=second_map_entry, + param_repl=param_repl, state=graph, sdfg=sdfg, ): @@ -189,12 +195,10 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non graph: The SDFG state we are operating on. sdfg: The SDFG we are operating on. """ - # NOTE: `self.map_*` actually stores the ID of the node. - # once we start adding and removing nodes it seems that their ID changes. - # Thus we have to save them here, this is a known behaviour in DaCe. - assert isinstance(graph, dace.SDFGState) - assert isinstance(self.first_map_exit, nodes.MapExit) - assert isinstance(self.second_map_entry, nodes.MapEntry) + # To ensures that the `{src,dst}_subset` are properly set, run initialization. + # See [issue 1708](https://github.com/spcl/dace/issues/1703) + for edge in graph.edges(): + edge.data.try_initialize(sdfg, graph, edge) first_map_exit: nodes.MapExit = self.first_map_exit second_map_entry: nodes.MapEntry = self.second_map_entry @@ -210,7 +214,8 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non ) # Now compute the partition. Because we have already renamed the parameters - # of the second Map, there is no need to perform any renaming. + # of the second Map, there is no need to perform any renaming, thus we can + # pass an empty `dict`. output_partition = self.partition_first_outputs( state=graph, sdfg=sdfg, @@ -221,6 +226,7 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non assert output_partition is not None # Make MyPy happy. pure_outputs, exclusive_outputs, shared_outputs = output_partition + # Now perform the actual rewiring, we handle each partition separately. if len(exclusive_outputs) != 0: self.handle_intermediate_set( intermediate_outputs=exclusive_outputs, @@ -250,9 +256,11 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non sdfg=sdfg, ) - # Above we have handled the input of the second map and moved them - # to the first map, now we must move the output of the first map - # to the second one, as this one is used. + # Now move the input of the second map, that has no connection to the first + # map, to the first map. This is needed because we will later delete the + # exit of the first map (which we have essentially handled above). Now + # we must handle the input of the second map (that has no connection to the + # first map) to the input of the first map. self.relocate_nodes( from_node=second_map_entry, to_node=first_map_entry, @@ -357,19 +365,20 @@ def partition_first_outputs( # cases, as handling them is essentially rerouting an edge, whereas # handling intermediate nodes is much more complicated. + # Empty Memlets are only allowed if they are in `\mathbb{P}`, which + # is also the only place they really make sense (for a map exit). + # Thus if we now found an empty Memlet we reject it. + if out_edge.data.is_empty(): + return None + # For us an intermediate node must always be an access node, because # everything else we do not know how to handle. It is important that # we do not test for non transient data here, because they can be # handled has shared intermediates. if not isinstance(intermediate_node, nodes.AccessNode): return None - if self.is_view(intermediate_node, sdfg): - return None - - # Empty Memlets are only allowed if they are in `\mathbb{P}`, which - # is also the only place they really make sense (for a map exit). - # Thus if we now found an empty Memlet we reject it. - if out_edge.data.is_empty(): + intermediate_desc: dace.data.Data = intermediate_node.desc(sdfg) + if self.is_view(intermediate_desc, sdfg): return None # It can happen that multiple edges converges at the `IN_` connector @@ -645,24 +654,11 @@ def handle_intermediate_set( if len(pre_exit_edges) != 1: raise NotImplementedError() pre_exit_edge = pre_exit_edges[0] - new_inter_shape_raw = symbolic.overapproximate(pre_exit_edge.data.subset.size()) - - # Over approximation will leave us with some unneeded size one dimensions. - # If they are removed some dace transformations (especially auto optimization) - # will have problems. - if not self.strict_dataflow: - squeezed_dims: List[int] = [] # These are the dimensions we removed. - new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. - for dim, (proposed_dim_size, full_dim_size) in enumerate(zip(new_inter_shape_raw, inter_shape)): - if full_dim_size == 1: # Must be kept! - new_inter_shape.append(proposed_dim_size) - elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. - squeezed_dims.append(dim) - else: - new_inter_shape.append(proposed_dim_size) - else: - squeezed_dims = [] - new_inter_shape = list(new_inter_shape_raw) + + (new_inter_shape_raw, new_inter_shape, squeezed_dims) = self.compute_reduced_intermediate( + producer_subset=pre_exit_edge.data.dst_subset, + inter_desc=inter_desc, + ) # This is the name of the new "intermediate" node that we will create. # It will only have the shape `new_inter_shape` which is basically its @@ -695,8 +691,6 @@ def handle_intermediate_set( # Get the subset that defined into which part of the old intermediate # the old output edge wrote to. We need that to adjust the producer # Memlets, since they now write into the new (smaller) intermediate. - assert pre_exit_edge.data.data == inter_name - assert pre_exit_edge.data.dst_subset is not None producer_offset = self.compute_offset_subset( original_subset=pre_exit_edge.data.dst_subset, intermediate_desc=inter_desc, @@ -704,17 +698,17 @@ def handle_intermediate_set( producer_offset=None, ) - # Memlets have a lot of additional informations, such as dynamic. - # To ensure that we get all of them, we will now copy them and modify - # the one that was originally there. We also hope that propagate will - # set the rest for us correctly. + # Memlets have a lot of additional informations, to ensure that we get + # all of them, we have to do it this way. The main reason for this is + # to handle the case were the "Memlet reverse direction", i.e. `data` + # refers to the other end of the connection than before. + assert pre_exit_edge.data.dst_subset is not None + new_pre_exit_memlet_src_subset = copy.deepcopy(pre_exit_edge.data.src_subset) + new_pre_exit_memlet_dst_subset = subsets.Range.from_array(new_inter_desc) + new_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) new_pre_exit_memlet.data = new_inter_name - new_pre_exit_memlet.dst_subset = subsets.Range.from_array(new_inter_desc) - # New we will reroute the output Memlet, thus it will no longer pass - # through the Map exit but through the newly created intermediate. - # NOTE: We will delete the previous edge later. new_pre_exit_edge = state.add_edge( pre_exit_edge.src, pre_exit_edge.src_conn, @@ -723,6 +717,11 @@ def handle_intermediate_set( new_pre_exit_memlet, ) + # We can update `{src, dst}_subset` only after we have inserted the + # edge, this is because the direction of the Memlet might change. + new_pre_exit_edge.data.src_subset = new_pre_exit_memlet_src_subset + new_pre_exit_edge.data.dst_subset = new_pre_exit_memlet_dst_subset + # We now handle the MemletTree defined by this edge. # The newly created edge, only handled the last collection step. for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children(include_self=False): @@ -766,6 +765,7 @@ def handle_intermediate_set( out_conn_name = "OUT_" + in_conn_name[3:] for inner_edge in state.out_edges_by_connector(second_map_entry, out_conn_name): + # TODO(phimuell): Lift this restriction assert inner_edge.data.data == inter_name # DIRECTION!! # As for the producer side, we now read from a smaller array, @@ -804,6 +804,7 @@ def handle_intermediate_set( # Now we have to make sure that all consumers are properly updated. for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(include_self=False): + # TODO(phimuell): Lift this restriction assert consumer_tree.edge.data.data == inter_name consumer_edge = consumer_tree.edge @@ -835,6 +836,9 @@ def handle_intermediate_set( del sdfg.arrays[inter_name] else: + # TODO(phimuell): Lift this restriction + assert pre_exit_edge.data.data == inter_name + # This is the shared mode, so we have to recreate the intermediate # node, but this time it is at the exit of the second map. state.remove_edge(pre_exit_edge) @@ -845,7 +849,6 @@ def handle_intermediate_set( # or preserve the output for the intermediate node. It is important # that we use the data that `preExitEdge` was used. final_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) - assert pre_exit_edge.data.data == inter_name final_pre_exit_memlet.other_subset = subsets.Range.from_array(new_inter_desc) new_pre_exit_conn = second_map_exit.next_connector() @@ -870,6 +873,50 @@ def handle_intermediate_set( state.remove_edge(out_edge) + def compute_reduced_intermediate( + self, + producer_subset: subsets.Range, + inter_desc: dace.data.Data, + ) -> Tuple[Tuple[int, ...], Tuple[int, ...], List[int]]: + """Compute the size of the new (reduced) intermediate. + + `MapFusion` does not only fuses map, but, depending on the situation, also + eliminates intermediate arrays between the two maps. To transmit data between + the two maps a new, but much smaller intermediate is needed. + + :return: The function returns a tuple with three values with the following meaning: + - The raw shape of the reduced intermediate. + - The cleared shape of the reduced intermediate, essentially the raw shape + with all shape 1 dimensions removed. + - Which dimensions of the raw shape have been removed to get the cleared shape. + + :param producer_subset: The subset that was used to write into the intermediate. + :param inter_desc: The data descriptor for the intermediate. + """ + assert producer_subset is not None + + # Over approximation will leave us with some unneeded size one dimensions. + # If they are removed some dace transformations (especially auto optimization) + # will have problems. + new_inter_shape_raw = symbolic.overapproximate(producer_subset.size()) + inter_shape = inter_desc.shape + if not self.strict_dataflow: + squeezed_dims: List[int] = [] # These are the dimensions we removed. + new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. + for dim, (proposed_dim_size, full_dim_size) in enumerate(zip(new_inter_shape_raw, inter_shape)): + if full_dim_size == 1: # Must be kept! + new_inter_shape.append(proposed_dim_size) + elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. + squeezed_dims.append(dim) + else: + new_inter_shape.append(proposed_dim_size) + else: + squeezed_dims = [] + new_inter_shape = list(new_inter_shape_raw) + + return (tuple(new_inter_shape_raw), tuple(new_inter_shape), squeezed_dims) + + def compute_offset_subset( self, original_subset: subsets.Range, @@ -981,9 +1028,7 @@ def can_topologically_be_fused( # We will now check if there exists a remapping that of the map parameter param_repl = self.find_parameter_remapping(first_map=first_map_entry.map, second_map=second_map_entry.map) - if param_repl is None: - return None - return None + return param_repl def has_inner_read_write_dependency( @@ -1607,11 +1652,11 @@ def find_subsets( def is_view( self, - node: nodes.AccessNode, + node: Union[nodes.AccessNode, data.Data], sdfg: SDFG, ) -> bool: """Tests if `node` points to a view or not.""" - node_desc: data.Data = node.desc(sdfg) + node_desc: data.Data = node if isinstance(node, data.Data) else node.desc(sdfg) return isinstance(node_desc, data.View) From 4832c3c7f0138c984292512d21b2d9b1fa6befdc Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 16 Dec 2024 08:53:42 +0100 Subject: [PATCH 097/115] Updated the map fusion test a little bit. Added a new special case. --- tests/transformations/mapfusion_test.py | 273 +++++++++++++++++++++--- 1 file changed, 238 insertions(+), 35 deletions(-) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index d47a89c7fb..c24163b961 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -34,7 +34,13 @@ def apply_fusion( """ num_maps_before = count_node(sdfg, nodes.MapEntry) org_sdfg = copy.deepcopy(sdfg) - sdfg.apply_transformations_repeated(MapFusion, validate=True, validate_all=True) + with dace.config.temporary_config(): + dace.Config.set("optimizer", "match_exception", value=True) + sdfg.apply_transformations_repeated( + MapFusion(strict_dataflow=True), + validate=True, + validate_all=True + ) num_maps_after = count_node(sdfg, nodes.MapEntry) has_processed = False @@ -428,7 +434,6 @@ def fusion_with_nested_sdfg_0(A: dace.int32[10], B: dace.int32[10], C: dace.int3 A[i] = tmp[i] * 2 sdfg = fusion_with_nested_sdfg_0.to_sdfg(simplify=True) - sdfg.view() # Because the transformation refuses to fuse dynamic edges. # We have to eliminate them. @@ -978,25 +983,42 @@ def test_fusion_non_strict_dataflow_implicit_dependency(): assert count == 0 -def test_inner_map_dependency(): - sdfg = dace.SDFG("inner_map_dependency_sdfg") +def _make_inner_conflict_shared_scalar( + has_conflict: bool, +) -> dace.SDFG: + """Generate the SDFG for tests with the inner dependency. + + If `has_conflict` is `True` then a transient scalar is used inside both Map bodies. + Therefore, `MapFusion` should not be able to fuse them. + In case `has_conflict` is `False` then different scalars are used which allows + fusing the two maps. + """ + sdfg = dace.SDFG( + "inner_map_dependency_sdfg" + if has_conflict + else "inner_map_dependency_resolved_sdfg" + ) state = sdfg.add_state(is_start_block=True) name_arrays = ["A", "T", "C"] for aname in name_arrays: sdfg.add_array( - name, + aname, shape=(10,), + dtype=dace.float64, transient=False, ) sdfg.arrays["T"].transient = True - sdfg.add_scalar( - "s", - dtype=dace.float64, - transient=True, - ) - A, T, C = (state.add_access(name) for name in name_arrays) - s1, s2 = (state.add_access("s") for _ in range(2)) + + name_scalars = ["s", "s"] if has_conflict else ["s1", "s2"] + for sname in set(name_scalars): + sdfg.add_scalar( + sname, + dtype=dace.float64, + transient=True, + ) + A, T, C = (state.add_access(aname) for aname in name_arrays) + s1, s2 = (state.add_access(sname) for sname in name_scalars) me1, mx1 = state.add_map( "map_1", @@ -1009,71 +1031,250 @@ def test_inner_map_dependency(): code="__out = __in1 + 1.0", ) + # Create the first map series. state.add_edge( - A, - None, - me1, - "IN_A", + A, None, + me1, "IN_A", dace.Memlet("A[0:10]") ) + me1.add_in_connector("IN_A") state.add_edge( - me1, - "OUT_A", - s1, - None, + me1, "OUT_A", + s1, None, dace.Memlet("A[__i0] -> [0]") ) - me1.add_in_connector("IN_A") - me1.add_out_connector("IUT_A") + me1.add_out_connector("OUT_A") state.add_edge( - s1, - None, - tskl1, - dace.Memlet("s[0]") + s1, None, + tsklt1, "__in1", + dace.Memlet(f"{s1.data}[0]") ) state.add_edge( - tsklt1, - "__out", - mx1, - "IN_T", + tsklt1, "__out", + mx1, "IN_T", dace.Memlet("T[__i0]") ) + mx1.add_in_connector("IN_T") state.add_edge( - mx1, - "OUT_T", - T, - None, + mx1, "OUT_T", + T, None, dace.Memlet("T[0:10]") ) - mx1.add_in_connector("IN_T") mx1.add_out_connector("OUT_T") + # Create the second map. + me2, mx2 = state.add_map( + "map_2", + ndrange={"__i0": "0:10"}, + ) + tsklt2 = state.add_tasklet( + "tskl2", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 + 3.0", + ) + state.add_edge( + T, None, + me2, "IN_T", + dace.Memlet("T[0:10]") + ) + me2.add_in_connector("IN_T") + state.add_edge( + me2, "OUT_T", + s2, None, + dace.Memlet("T[__i0]") + ) + me2.add_out_connector("OUT_T") + state.add_edge( + s2, None, + tsklt2, "__in1", + dace.Memlet(f"{s2.data}[0]") + ) + state.add_edge( + tsklt2, "__out", + mx2, "IN_C", + dace.Memlet("C[__i0]") + ) + mx2.add_in_connector("IN_C") + state.add_edge( + mx2, "OUT_C", + C, None, + dace.Memlet("C[0:10]") + ) + mx2.add_out_connector("OUT_C") + sdfg.validate() + return sdfg +def test_inner_map_dependency(): + # Because the scalar is not shared the maps can not be fused. + sdfg = _make_inner_conflict_shared_scalar(has_conflict=True) + apply_fusion(sdfg, removed_maps=0, final_maps=2) +def test_inner_map_dependency_resolved(): + # Because the scalars are different, the scalar + sdfg = _make_inner_conflict_shared_scalar(has_conflict=False) + apply_fusion(sdfg, removed_maps=1, final_maps=1) +def _impl_fusion_intermediate_different_access(modified_shape: bool): + def ref(A, B): + T = np.zeros((A.shape[0] + 1, 2)) + for i in range(A.shape[0]): + T[i + 1, 0] = A[i] * 2 + T[i + 1, 1] = A[i] / 2 + for j in range(A.shape[0]): + B[j] = np.sin(T[j+1, 1]) + sdfg = dace.SDFG("fusion_intermediate_different_access_sdfg") + state = sdfg.add_state(is_start_block=True) + for name in "AB": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.add_array( + "T", + shape=(11, 2), + dtype=dace.float64, + transient=True, + ) + # For this intermediate, which essentially represents `[A[i] * 2, A[i] / 2]` in + # the reference above, there are two important remarks: + # - It exists because one data stream, i.e. `T[i + 1, 1]` would be dead data flow + # and currently the transformation can not handle this. + # - The strange shape is because the transformation can not handle this case. + # This is a limitation of the implementation. + sdfg.add_array( + "temp", + shape=( + (1, 2,) + if modified_shape + else (2,) + ), + dtype=dace.float64, + transient=True, + ) + A, B, T, temp = (state.add_access(name) for name in ["A", "B", "T", "temp"]) + me1, mx1 = state.add_map( + "first_map", + ndrange={"__i0": "0:10"}, + ) + state.add_edge( + A, None, + me1, "IN_A", + dace.Memlet("A[0:10]") + ) + me1.add_in_connector("IN_A") + me1.add_out_connector("OUT_A") + tsklt1_1 = state.add_tasklet( + "tsklt1_1", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 * 2.0", + ) + state.add_edge( + me1, "OUT_A", + tsklt1_1, "__in1", + dace.Memlet("A[__i0]") + ) + state.add_edge( + tsklt1_1, "__out", + temp, None, + dace.Memlet( + "temp[0, 0]" + if modified_shape + else "temp[0]" + ) + ) + tsklt1_2 = state.add_tasklet( + "tsklt1_2", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 / 2.0", + ) + state.add_edge( + me1, "OUT_A", + tsklt1_2, "__in1", + dace.Memlet("A[__i0]") + ) + state.add_edge( + tsklt1_2, "__out", + temp, None, + dace.Memlet( + "temp[0, 1]" + if modified_shape + else "temp[1]" + ) + ) + state.add_edge( + temp, None, + mx1, "IN_temp", + dace.Memlet( + "temp[0, 0:2] -> [__i0 + 1, 0:2]" + if modified_shape + else "temp[0:2] -> [__i0 + 1, 0:2]" + ) + ) + state.add_edge( + mx1, "OUT_temp", + T, None, + dace.Memlet("T[1:11, 0:2]") + ) + mx1.add_in_connector("IN_temp") + mx1.add_out_connector("OUT_temp") + state.add_mapped_tasklet( + "comp2", + map_ranges={"__i1": "0:10"}, + inputs={"__in1": dace.Memlet("T[__i1 + 1, 1]")}, + code="__out = math.sin(__in1)", + outputs={"__out": dace.Memlet("B[__i1]")}, + input_nodes={T}, + output_nodes={B}, + external_edges=True, + ) + sdfg.validate() + apply_fusion(sdfg, removed_maps=1, final_maps=1) + + args_ref = { + 'A': np.array(np.random.rand(10), dtype=np.float64, copy=True), + 'B': np.array(np.random.rand(10), dtype=np.float64, copy=True), + } + args_res = copy.deepcopy(args_ref) + ref(**args_ref) + sdfg(**args_res) + for arg in args_ref.keys(): + arg_ref = args_ref[arg] + arg_res = args_res[arg] + assert np.allclose(arg_ref, arg_res) +def test_fusion_intermediate_different_access(): + _impl_fusion_intermediate_different_access(modified_shape=False) +def test_fusion_intermediate_different_access_mod_shape(): + _impl_fusion_intermediate_different_access(modified_shape=True) if __name__ == '__main__': + test_fusion_intermediate_different_access() + test_fusion_intermediate_different_access_mod_shape() test_fusion_non_strict_dataflow_implicit_dependency() test_fusion_strict_dataflow_pointwise() test_fusion_strict_dataflow_not_pointwise() @@ -1098,3 +1299,5 @@ def test_inner_map_dependency(): test_offset_correction_scalar_read() test_offset_correction_empty() test_different_offsets() + test_inner_map_dependency() + test_inner_map_dependency_resolved() From fd3b48a1536cbf1b10bd1148d1a74bf306810f86 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 16 Dec 2024 09:10:45 +0100 Subject: [PATCH 098/115] Added a test for the next generation of MapFusion. --- tests/transformations/mapfusion_test.py | 123 ++++++++++++++++++++++++ 1 file changed, 123 insertions(+) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index c24163b961..8305fda1bb 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -6,6 +6,7 @@ import dace import copy import uuid +import pytest from dace import SDFG, SDFGState from dace.sdfg import nodes @@ -1272,6 +1273,128 @@ def test_fusion_intermediate_different_access_mod_shape(): _impl_fusion_intermediate_different_access(modified_shape=True) +@pytest.mark.skip(reason="This feature is not yet fully supported.") +def test_fusion_multiple_producers_consumers(): + """Multiple producer and consumer nodes. + + This test is very similar to the `test_fusion_intermediate_different_access()` + and `test_fusion_intermediate_different_access_mod_shape()` test, with the + exception that now full data is used in the second map. + However, currently `MapFusion` only supports a single producer, thus this test can + not run. + """ + def ref(A, B): + T = np.zeros((A.shape[0], 2)) + for i in range(A.shape[0]): + T[i, 0] = A[i] * 2 + T[i, 1] = A[i] / 2 + for j in range(A.shape[0]): + B[j] = np.sin(T[j, 1]) + np.cos(T[j, 0]) + + sdfg = dace.SDFG("fusion_multiple_producers_consumers_sdfg") + state = sdfg.add_state(is_start_block=True) + for name in "AB": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.add_array( + "T", + shape=(10, 2), + dtype=dace.float64, + transient=True, + ) + + A, B, T = (state.add_access(name) for name in ["A", "B", "T"]) + + me1, mx1 = state.add_map( + "first_map", + ndrange={"__i0": "0:10"}, + ) + + state.add_edge( + A, None, + me1, "IN_A", + dace.Memlet("A[0:10]") + ) + me1.add_in_connector("IN_A") + me1.add_out_connector("OUT_A") + + tsklt1_1 = state.add_tasklet( + "tsklt1_1", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 * 2.0", + ) + state.add_edge( + me1, "OUT_A", + tsklt1_1, "__in1", + dace.Memlet("A[__i0]") + ) + state.add_edge( + tsklt1_1, "__out", + mx1, "IN_T", + dace.Memlet("T[__i0, 0]") + ) + + tsklt1_2 = state.add_tasklet( + "tsklt1_2", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 / 2.0", + ) + state.add_edge( + me1, "OUT_A", + tsklt1_2, "__in1", + dace.Memlet("A[__i0]") + ) + state.add_edge( + tsklt1_2, "__out", + mx1, "IN_T", + dace.Memlet("T[__i0, 1]") + ) + mx1.add_in_connector("IN_T") + + state.add_edge( + mx1, "OUT_T", + T, None, + dace.Memlet("T[0:10, 0:2]"), + ) + mx1.add_out_connector("OUT_T") + + state.add_mapped_tasklet( + "comp2", + map_ranges={"__i1": "0:10"}, + inputs={ + "__in1": dace.Memlet("T[__i1, 1]"), + "__in2": dace.Memlet("T[__i1, 0]"), + }, + code="__out = math.sin(__in1) + math.cos(__in2)", + outputs={"__out": dace.Memlet("B[__i1]")}, + input_nodes={T}, + output_nodes={B}, + external_edges=True, + ) + sdfg.validate() + + apply_fusion(sdfg, removed_maps=1, final_maps=1) + + args_ref = { + 'A': np.array(np.random.rand(10), dtype=np.float64, copy=True), + 'B': np.array(np.random.rand(10), dtype=np.float64, copy=True), + } + args_res = copy.deepcopy(args_ref) + + ref(**args_ref) + sdfg(**args_res) + for arg in args_ref.keys(): + arg_ref = args_ref[arg] + arg_res = args_res[arg] + assert np.allclose(arg_ref, arg_res) + + if __name__ == '__main__': test_fusion_intermediate_different_access() test_fusion_intermediate_different_access_mod_shape() From 8fa7cb2563b98db6b5d4a40ba24e2f64638560e7 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 16 Dec 2024 09:21:03 +0100 Subject: [PATCH 099/115] Added a new test. --- tests/transformations/mapfusion_test.py | 118 ++++++++++++++++++++++++ 1 file changed, 118 insertions(+) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 8305fda1bb..0a408ab10b 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -1395,7 +1395,125 @@ def ref(A, B): assert np.allclose(arg_ref, arg_res) +def test_fusion_multiple_consumers(): + """The intermediate is consumed multiple times in the second map. + """ + def ref(A, B, C): + T = np.zeros_like(A) + for i in range(A.shape[0]): + T[i] = np.sin(A[i] * 2) + for j in range(A.shape[0]): + B[j] = T[j] * 3. + C[j] = T[j] - 1. + + sdfg = dace.SDFG("fusion_multiple_consumers_sdfg") + state = sdfg.add_state(is_start_block=True) + for name in "ABCT": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["T"].transient = True + + A, B, C, T = (state.add_access(name) for name in ["A", "B", "C", "T"]) + + state.add_mapped_tasklet( + "comp1", + map_ranges={"__i1": "0:10"}, + inputs={ + "__in1": dace.Memlet("A[__i1]"), + }, + code="__out = math.sin(2 * __in1)", + outputs={"__out": dace.Memlet("T[__i1]")}, + input_nodes={A}, + output_nodes={T}, + external_edges=True, + ) + + me2, mx2 = state.add_map( + "second_map", + ndrange={"__i0": "0:10"}, + ) + + state.add_edge( + T, None, + me2, "IN_T", + dace.Memlet("T[0:10]", volume=20) + ) + me2.add_in_connector("IN_T") + me2.add_out_connector("OUT_T") + + tsklt2_1 = state.add_tasklet( + "tsklt2_1", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 * 3.0", + ) + state.add_edge( + me2, "OUT_T", + tsklt2_1, "__in1", + dace.Memlet("T[__i0]") + ) + state.add_edge( + tsklt2_1, "__out", + mx2, "IN_B", + dace.Memlet("B[__i0]") + ) + + tsklt2_2 = state.add_tasklet( + "tsklt2_2", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 - 1.0", + ) + state.add_edge( + me2, "OUT_T", + tsklt2_2, "__in1", + dace.Memlet("T[__i0]") + ) + state.add_edge( + tsklt2_2, "__out", + mx2, "IN_C", + dace.Memlet("C[__i0]") + ) + mx2.add_in_connector("IN_B") + mx2.add_in_connector("IN_C") + + state.add_edge( + mx2, "OUT_B", + B, None, + dace.Memlet("B[0:10]"), + ) + state.add_edge( + mx2, "OUT_C", + C, None, + dace.Memlet("C[0:10]"), + ) + mx2.add_out_connector("OUT_B") + mx2.add_out_connector("OUT_C") + sdfg.validate() + + apply_fusion(sdfg, removed_maps=1, final_maps=1) + + args_ref = { + 'A': np.array(np.random.rand(10), dtype=np.float64, copy=True), + 'B': np.array(np.random.rand(10), dtype=np.float64, copy=True), + 'C': np.array(np.random.rand(10), dtype=np.float64, copy=True), + } + args_res = copy.deepcopy(args_ref) + + ref(**args_ref) + sdfg(**args_res) + for arg in args_ref.keys(): + arg_ref = args_ref[arg] + arg_res = args_res[arg] + assert np.allclose(arg_ref, arg_res) + + if __name__ == '__main__': + test_fusion_multiple_consumers() test_fusion_intermediate_different_access() test_fusion_intermediate_different_access_mod_shape() test_fusion_non_strict_dataflow_implicit_dependency() From 91391493b0dc7bb543bf9adfa5d7db547aa9f135 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 16 Dec 2024 11:40:42 +0100 Subject: [PATCH 100/115] The test `test_fusion_with_nested_sdfg_0` is now explicitly constructured. Before the output edges were before set to dynamic. However, this was not true as it was always set, thus the new map fusion did not fuse them. My first attempt was to just disable the `dynamic` property, but now the SDFG is generated manually. It is almost the same, but uses lesss symbol, as it was simpler to implement it this way, and we are now using float. --- tests/transformations/mapfusion_test.py | 149 ++++++++++++++++++++++-- 1 file changed, 139 insertions(+), 10 deletions(-) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 0a408ab10b..b8b40d21a8 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -423,9 +423,8 @@ def inner_product(A: dace.float32[N], B: dace.float32[N], out: dace.float32[1]): def test_fusion_with_nested_sdfg_0(): - @dace.program - def fusion_with_nested_sdfg_0(A: dace.int32[10], B: dace.int32[10], C: dace.int32[10]): - tmp = np.empty([10], dtype=np.int32) + def ref(A, B, C): + tmp = np.zeros_like(A) for i in dace.map[0:10]: if C[i] < 0: tmp[i] = B[i] - A[i] @@ -433,14 +432,129 @@ def fusion_with_nested_sdfg_0(A: dace.int32[10], B: dace.int32[10], C: dace.int3 tmp[i] = B[i] + A[i] for i in dace.map[0:10]: A[i] = tmp[i] * 2 - - sdfg = fusion_with_nested_sdfg_0.to_sdfg(simplify=True) - # Because the transformation refuses to fuse dynamic edges. - # We have to eliminate them. - for state in sdfg.states(): - for edge in state.edges(): - edge.data.dynamic = False + def _make_sdfg() -> dace.SDFG: + sdfg = SDFG("fusion_with_nested_sdfg_0") + state = sdfg.add_state(is_start_block=True) + + for name in "ABCT": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["T"].transient = True + + me1, mx1 = state.add_map("first_map", ndrange={"__i0": "0:10"}) + nsdfg = state.add_nested_sdfg( + sdfg=_make_nested_sdfg(), + parent=sdfg, + inputs={"a", "b", "c"}, + outputs={"t"}, + symbol_mapping={}, + ) + + for name in "ABC": + state.add_edge( + state.add_access(name), None, + me1, "IN_" + name, + dace.Memlet(f"{name}[0:10]"), + ) + me1.add_in_connector("IN_" + name) + state.add_edge( + me1, "OUT_" + name, + nsdfg, name.lower(), + dace.Memlet(f"{name}[__i0]"), + ) + me1.add_out_connector("OUT_" + name) + state.add_edge( + nsdfg, "t", + mx1, "IN_T", + dace.Memlet("T[__i0]"), + ) + T = state.add_access("T") + state.add_edge( + mx1, "OUT_T", + T, None, + dace.Memlet("T[0:10]"), + ) + mx1.add_in_connector("IN_T") + mx1.add_out_connector("OUT_T") + + state.add_mapped_tasklet( + "comp2", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("T[__i0]")}, + code="__out = __in1 * 2", + outputs={"__out": dace.Memlet("A[__i0]")}, + input_nodes={T}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + def _make_nested_sdfg() -> dace.SDFG: + sdfg = SDFG("Nested") + + for name in "abct": + sdfg.add_scalar( + name, + dtype=dace.float64, + transient=False, + ) + + state_head = sdfg.add_state("head_state", is_start_block=True) + state_if_guard = sdfg.add_state("state_if_guard") + sdfg.add_edge( + state_head, + state_if_guard, + dace.InterstateEdge( + condition="1", + assignments={"__tmp2": "c < 0.0"}, + ) + ) + + def _make_branch_tasklet( + state: dace.SDFGState, + code: str, + ) -> None: + tasklet = state.add_tasklet( + state.label + "_tasklet", + inputs={"__in1", "__in2"}, + code=code, + outputs={"__out"}, + ) + state.add_edge( + state.add_access("b"), None, + tasklet, "__in1", + dace.Memlet("b[0]"), + ) + state.add_edge( + state.add_access("a"), None, + tasklet, "__in2", + dace.Memlet("a[0]"), + ) + state.add_edge( + tasklet, "__out", + state.add_access("t"), None, + dace.Memlet("t[0]"), + ) + + state_trueb = sdfg.add_state("true_branch") + _make_branch_tasklet(state_trueb, "__out = __in1 - __in2") + state_falseb = sdfg.add_state("false_branch") + _make_branch_tasklet(state_falseb, "__out = __in1 + __in2") + state_if_end = sdfg.add_state("if_join") + + sdfg.add_edge(state_if_guard, state_trueb, dace.InterstateEdge(condition="__tmp2")) + sdfg.add_edge(state_if_guard, state_falseb, dace.InterstateEdge(condition="not __tmp2")) + sdfg.add_edge(state_falseb, state_if_end, dace.InterstateEdge()) + sdfg.add_edge(state_trueb, state_if_end, dace.InterstateEdge()) + sdfg.validate() + return sdfg + + sdfg = _make_sdfg() apply_fusion(sdfg) for sd in sdfg.all_sdfgs_recursive(): @@ -453,6 +567,21 @@ def fusion_with_nested_sdfg_0(A: dace.int32[10], B: dace.int32[10], C: dace.int3 assert isinstance(dst, dace.nodes.AccessNode) + args_ref = { + 'A': np.array(np.random.rand(10), dtype=np.float64, copy=True), + 'B': np.array(np.random.rand(10), dtype=np.float64, copy=True), + 'C': np.array(np.random.rand(10) - 0.5, dtype=np.float64, copy=True), + } + args_res = copy.deepcopy(args_ref) + + ref(**args_ref) + sdfg(**args_res) + for arg in args_ref.keys(): + arg_ref = args_ref[arg] + arg_res = args_res[arg] + assert np.allclose(arg_ref, arg_res), f"Failed in {arg}" + + def test_fusion_with_nested_sdfg_1(): @dace.program From 0a5aeaf94006c572cceefa2b5065aa037f3c2fc6 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 16 Dec 2024 12:15:42 +0100 Subject: [PATCH 101/115] Allowed that consumer edge in MapFusion are dynamic. For such edges we are sure that the data exists, so it is just a conditional read, which is fine. --- dace/transformation/dataflow/map_fusion.py | 10 ++++++---- tests/transformations/mapfusion_test.py | 10 +++------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 8334d07814..9ec5bf3463 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -405,6 +405,7 @@ def partition_first_outputs( if isinstance(producer_edge.src, nodes.AccessNode) and self.is_view(producer_edge.src, sdfg): return None if producer_edge.data.dynamic: + # TODO(phimuell): Find out if this restriction could be lifted, but it is unlikely. return None if producer_edge.data.wcr is not None: return None @@ -456,14 +457,15 @@ def partition_first_outputs( # Now we look at all edges that leave the second map entry, i.e. the # edges that feeds the consumer and define what is read inside the map. # We do not check them, but collect them and inspect them. - # NOTE: The subset still uses the old iteration variables. + # NOTE1: The subset still uses the old iteration variables. + # NOTE2: In case of consumer Memlet we explicitly allow dynamic Memlets. + # This is different compared to the producer Memlet. The reason is + # because in a consumer the data is conditionally read, so the data + # has to exists anyway. for inner_consumer_edge in state.out_edges_by_connector( second_map_entry, "OUT_" + intermediate_node_out_edge.dst_conn[3:]): if inner_consumer_edge.data.src_subset is None: return None - if inner_consumer_edge.data.dynamic: - # TODO(phimuell): Is this restriction necessary, I am not sure. - return None consumer_subsets.append(inner_consumer_edge.data.src_subset) assert found_second_map, f"Found '{intermediate_node}' which looked like a pure node, but is not one." assert len(consumer_subsets) != 0 diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index b8b40d21a8..dd6aa077e8 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -583,7 +583,9 @@ def _make_branch_tasklet( def test_fusion_with_nested_sdfg_1(): - + + # As a side effect this test also ensures that dynamic consumer edges, does not + # impact fusing, i.e. allow that fusion can take place. @dace.program def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int32[10]): tmp = np.empty([10], dtype=np.int32) @@ -600,12 +602,6 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 B[i] = tmp[i] * 2 sdfg = fusion_with_nested_sdfg_1.to_sdfg(simplify=True) - - # Because the transformation refuses to fuse dynamic edges. - # We have to eliminate them. - for state in sdfg.states(): - for edge in state.edges(): - edge.data.dynamic = False apply_fusion(sdfg) if len(sdfg.states()) != 1: From e2bc10d0b017a2e0fbea0a15956637e8499abd4c Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 16 Dec 2024 13:10:19 +0100 Subject: [PATCH 102/115] Changed the doc string to the Sphinx one. --- dace/transformation/dataflow/map_fusion.py | 188 ++++++++------------- 1 file changed, 74 insertions(+), 114 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 9ec5bf3463..a0e479af6a 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -27,24 +27,16 @@ class MapFusion(transformation.SingleStateTransformation): Furthermore, shared intermediates, see `partition_first_outputs()` will only be created if the data is not referred downstream in the dataflow. - Args: - only_inner_maps: Only match Maps that are internal, i.e. inside another Map. - only_toplevel_maps: Only consider Maps that are at the top. - strict_dataflow: Which dataflow mode should be used, see above. - - Notes: - - This transformation modifies more nodes than it matches. - - After the transformation has been applied simplify should be run to remove + :param only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + :param only_toplevel_maps: Only consider Maps that are at the top. + :param strict_dataflow: Which dataflow mode should be used, see above. + + :note: This transformation modifies more nodes than it matches. + :note: After the transformation has been applied simplify should be run to remove some dead data flow, that was introduced to ensure validity. - - A `MapFusion` obejct can be initialized and be reused. However, + :note: A `MapFusion` obejct can be initialized and be reused. However, after new access nodes are added to any state, it is no longer valid to use the object. - - Todo: - - Consider the case that only shared nodes are created (thus no inspection of - the graph is needed) and make all shared. Then use the dead dataflow - elimination transformation to get rid of the ones we no longer need. - - Increase the applicability. """ # Pattern Nodes @@ -115,9 +107,9 @@ def can_be_applied( """Tests if the matched Maps can be merged. The two Maps are mergeable iff: - - Checks general requirements, see `can_topologically_be_fused()`. - - Tests if there are read write dependencies. - - Tests if the decomposition exists. + * Checks general requirements, see `can_topologically_be_fused()`. + * Tests if there are read write dependencies. + * Tests if the decomposition exists. """ # To ensures that the `{src,dst}_subset` are properly set, run initialization. # See [issue 1708](https://github.com/spcl/dace/issues/1703) @@ -191,9 +183,8 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non By assumption we do not have to rename anything. - Args: - graph: The SDFG state we are operating on. - sdfg: The SDFG we are operating on. + :param graph: The SDFG state we are operating on. + :param sdfg: The SDFG we are operating on. """ # To ensures that the `{src,dst}_subset` are properly set, run initialization. # See [issue 1708](https://github.com/spcl/dace/issues/1703) @@ -295,15 +286,15 @@ def partition_first_outputs( The output edges of the first map are partitioned into three distinct sets, defined as follows: - - Pure Output Set `\mathbb{P}`: + * Pure Output Set `\mathbb{P}`: These edges exits the first map and does not enter the second map. These outputs will be simply be moved to the output of the second map. - - Exclusive Intermediate Set `\mathbb{E}`: + * Exclusive Intermediate Set `\mathbb{E}`: Edges in this set leaves the first map exit, enters an access node, from where a Memlet then leads immediately to the second map. The memory referenced by this access node is not used anywhere else, thus it can be removed. - - Shared Intermediate Set `\mathbb{S}`: + * Shared Intermediate Set `\mathbb{S}`: These edges are very similar to the one in `\mathbb{E}` except that they are used somewhere else, thus they can not be removed and have to be recreated as output of the second map. @@ -531,11 +522,10 @@ def relocate_nodes( `from_node` has degree zero. The function assumes that the parameter renaming was already done. - Args: - from_node: Node from which the edges should be removed. - to_node: Node to which the edges should reconnect. - state: The state in which the operation happens. - sdfg: The SDFG that is modified. + :param from_node: Node from which the edges should be removed. + :param to_node: Node to which the edges should reconnect. + :param state: The state in which the operation happens. + :param sdfg: The SDFG that is modified. """ # Now we relocate empty Memlets, from the `from_node` to the `to_node` @@ -624,17 +614,15 @@ def handle_intermediate_set( the SDFG. While in shared mode the intermediate node will be preserved. The function assumes that the parameter renaming was already done. - Args: - intermediate_outputs: The set of outputs, that should be processed. - state: The state in which the map is processed. - sdfg: The SDFG that should be optimized. - first_map_exit: The exit of the first/top map. - second_map_entry: The entry of the second map. - second_map_exit: The exit of the second map. - is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. - - Notes: - Before the transformation the `state` does not have to be valid and + :param intermediate_outputs: The set of outputs, that should be processed. + :param state: The state in which the map is processed. + :param sdfg: The SDFG that should be optimized. + :param first_map_exit: The exit of the first/top map. + :param second_map_entry: The entry of the second map. + :param second_map_exit: The exit of the second map. + :param is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. + + :note: Before the transformation the `state` does not have to be valid and after this function has run the state is (most likely) invalid. """ @@ -887,10 +875,10 @@ def compute_reduced_intermediate( the two maps a new, but much smaller intermediate is needed. :return: The function returns a tuple with three values with the following meaning: - - The raw shape of the reduced intermediate. - - The cleared shape of the reduced intermediate, essentially the raw shape + * The raw shape of the reduced intermediate. + * The cleared shape of the reduced intermediate, essentially the raw shape with all shape 1 dimensions removed. - - Which dimensions of the raw shape have been removed to get the cleared shape. + * Which dimensions of the raw shape have been removed to get the cleared shape. :param producer_subset: The subset that was used to write into the intermediate. :param inter_desc: The data descriptor for the intermediate. @@ -937,12 +925,11 @@ def compute_offset_subset( case the function computes the correction for the consumer side, i.e. the memlet tree that originates at `intermediate_desc`. - Args: - original_subset: The original subset that was used to write into the - intermediate, must be renamed to the final map parameter. - intermediate_desc: The original intermediate data descriptor. - map_params: The parameter of the final map. - producer_offset: The correction that was applied to the producer side. + :param original_subset: The original subset that was used to write into the + intermediate, must be renamed to the final map parameter. + :param intermediate_desc: The original intermediate data descriptor. + :param map_params: The parameter of the final map. + :param producer_offset: The correction that was applied to the producer side. """ assert not isinstance(intermediate_desc, data.View) final_offset: subsets.Range = None @@ -996,9 +983,9 @@ def can_topologically_be_fused( This function only checks constrains that are common between serial and parallel map fusion process, which includes: - - The scope of the maps. - - The scheduling of the maps. - - The map parameters. + * The scope of the maps. + * The scheduling of the maps. + * The map parameters. :return: If the maps can not be topologically fused the function returns `None`. If they can be fused the function returns `dict` that describes parameter @@ -1045,12 +1032,12 @@ def has_inner_read_write_dependency( The function will scan and anaysize the body of the two Maps and look for inconsistencies. To detect them the function will scan the body of the maps and examine the all AccessNodes and apply the following rules: - - If an AccessNode refers to a View, it is ignored. Because the source is + * If an AccessNode refers to a View, it is ignored. Because the source is either on the outside, in which case `has_read_write_dependency()` takes care of it, or the data source is inside the Map body itself. - - An inconsistency is detected, if in each bodies there exists an AccessNode + * An inconsistency is detected, if in each bodies there exists an AccessNode that refer to the same data. - - An inconsistency is detected, if there exists an AccessNode that refers + * An inconsistency is detected, if there exists an AccessNode that refers to non transient data. This is an implementation detail and could be lifted. @@ -1106,10 +1093,10 @@ def has_read_write_dependency( """Test if there is a read write dependency between the two maps to be fused. The function checks two different things. - - The function will make sure that there is no read write dependency between + * The function will make sure that there is no read write dependency between the input and output of the fused maps. For that it will inspect the respective subsets. - - The second part partially checks the intermediate nodes, it mostly ensures + * The second part partially checks the intermediate nodes, it mostly ensures that there are not views and that they are not used as inputs or outputs at the same time. However, the function will not check for read write conflicts in this set, this is done in the partition function. @@ -1242,8 +1229,7 @@ def test_if_subsets_are_point_wise(self, subsets_to_check: List[subsets.Subset]) If the subsets originates from different maps, then they must have been renamed. - Args: - subsets_to_check: The list of subsets that should be checked. + :param subsets_to_check: The list of subsets that should be checked. """ assert len(subsets_to_check) > 1 @@ -1281,21 +1267,19 @@ def is_shared_data( """Tests if `data` is shared data, an can not be removed. Interstate data is used to transmit data, this includes: - - The data is referred in multiple states. - - The data is referred to multiple times in the same state, either the state + * The data is referred in multiple states. + * The data is referred to multiple times in the same state, either the state has multiple access nodes for that data or an access node has an out degree larger than one. - - The data is read inside interstate edges. + * The data is read inside interstate edges. This definition is stricter than the one employed by `SDFG.shared_transients()`, as it also includes usage within a state. - Args: - transient: The transient that should be checked. - sdfg: The SDFG containing the array. + :param transient: The transient that should be checked. + :param sdfg: The SDFG containing the array. - Note: - The function computes the this set once for every SDFG and then caches it. + :note: The function computes the this set once for every SDFG and then caches it. There is no mechanism to detect if the cache must be evicted. However, as long as no additional data is added to the SDFG, there is no problem. """ @@ -1312,8 +1296,7 @@ def _compute_shared_data_in( See the documentation for `self.is_shared_data()` for a description. - Args: - sdfg: The SDFG for which the set of shared data should be computed. + :param sdfg: The SDFG for which the set of shared data should be computed. """ # Shared data of this SDFG. shared_data: Set[str] = set() @@ -1393,9 +1376,8 @@ def find_parameter_remapping(self, first_map: nodes.Map, second_map: nodes.Map) then the function returns an empty `dict`. If no remapping exists, then the function will return `None`. - Args: - first_map: The first map (these parameters will be replaced). - second_map: The second map, these parameters acts as source. + :param first_map: The first map (these parameters will be replaced). + :param second_map: The second map, these parameters acts as source. """ # The parameter names @@ -1476,11 +1458,10 @@ def rename_map_parameters( handled correct. The function assumes that a proper replacement exists. The replacement is computed by calling `self.find_parameter_remapping()`. - Args: - first_map: The first map (these are the final parameter). - second_map: The second map, this map will be replaced. - second_map_entry: The entry node of the second map. - state: The SDFGState on which we operate. + :param first_map: The first map (these are the final parameter). + :param second_map: The second map, this map will be replaced. + :param second_map_entry: The entry node of the second map. + :param state: The SDFGState on which we operate. """ # Compute the replacement dict. repl_dict: Dict[str, str] = self.find_parameter_remapping(first_map=first_map, second_map=second_map) @@ -1515,10 +1496,9 @@ def is_node_reachable_from( to `end` the function returns `True`. If the node is never found `False` is returned. - Args: - graph: The graph to operate on. - begin: The start of the DFS. - end: The node that should be located. + :param graph: The graph to operate on. + :param begin: The start of the DFS. + :param end: The node that should be located. """ def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: @@ -1553,10 +1533,9 @@ def _is_data_accessed_downstream( If no such node is found it will return `False`. Note that the node `begin` will be ignored. - Args: - data: The name of the data to look for. - graph: The graph to explore. - begin: The node to start exploration; The node itself is ignored. + :param data: The name of the data to look for. + :param graph: The graph to explore. + :param begin: The node to start exploration; The node itself is ignored. """ def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: return (edge.dst for edge in graph.out_edges(node)) @@ -1587,9 +1566,8 @@ def get_access_set( The function returns a set that contains all access nodes that were found. It is important that this set will also contain views. - Args: - scope_node: The scope node that should be evaluated. - state: The state in which we operate. + :param scope_node: The scope node that should be evaluated. + :param state: The state in which we operate. """ if isinstance(scope_node, nodes.MapEntry): get_edges = lambda node: state.in_edges(node) @@ -1674,35 +1652,17 @@ def track_view( access node. For convenience, if `view` is not a `View` the argument will be returned. - Args: - view: The view that should be traced. - state: The state in which we operate. - sdfg: The SDFG on which we operate. + :param view: The view that should be traced. + :param state: The state in which we operate. + :param sdfg: The SDFG on which we operate. """ # Test if it is a view at all, if not return the passed node as source. if not self.is_view(view, sdfg): return view - # First determine if the view is used for reading or writing. - curr_edge = dace.sdfg.utils.get_view_edge(state, view) - if curr_edge is None: - raise RuntimeError(f"Failed to determine the direction of the view '{view}'.") - if curr_edge.dst_conn == "views": - # The view is used for reading. - next_node = lambda curr_edge: curr_edge.src - elif curr_edge.src_conn == "views": - # The view is used for writing. - next_node = lambda curr_edge: curr_edge.dst - else: - raise RuntimeError(f"Failed to determine the direction of the view '{view}' | {curr_edge}.") - - # Now trace the view back. - org_view = view - view = next_node(curr_edge) - while self.is_view(view, sdfg): - curr_edge = dace.sdfg.utils.get_view_edge(state, view) - if curr_edge is None: - raise RuntimeError(f"View tracing of '{org_view}' failed at note '{view}'.") - view = next_node(curr_edge) - return view + # This is the node that defines the view. + defining_node = dace.sdfg.utils.get_last_view_node(state, view) + assert isinstance(defining_node, nodes.AccessNode) + assert not self.is_view(defining_node, sdfg) + return defining_node From aa3619fd02fac6f0aa92341e6f7d39205ac44d9a Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 16 Dec 2024 14:01:30 +0100 Subject: [PATCH 103/115] Fixed some missing test. --- tests/transformations/apply_to_test.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/transformations/apply_to_test.py b/tests/transformations/apply_to_test.py index 30ccb53f28..34ff114ac5 100644 --- a/tests/transformations/apply_to_test.py +++ b/tests/transformations/apply_to_test.py @@ -60,11 +60,11 @@ def test_applyto_pattern(): assert MapFusion.can_be_applied_to( sdfg, - map_exit_1=mult_exit, - intermediate_access_node=access_node, - map_entry_2=add_entry + first_map_exit=mult_exit, + array=access_node, + second_map_entry=add_entry ) - MapFusion.apply_to(sdfg, map_exit_1=mult_exit, intermediate_access_node=access_node, map_entry_2=add_entry) + MapFusion.apply_to(sdfg, first_map_exit=mult_exit, array=access_node, second_map_entry=add_entry) assert len([node for node in state.nodes() if isinstance(node, dace.nodes.MapEntry)]) == 1 @@ -88,9 +88,9 @@ def test_applyto_pattern_2(): assert not MapFusion.can_be_applied_to( sdfg, - map_exit_1=map_exit_1, - intermediate_access_node=tmp, - map_entry_2=map_entry_2 + first_map_exit=map_exit_1, + array=tmp, + second_map_entry=map_entry_2 ) with pytest.raises( ValueError, @@ -99,9 +99,9 @@ def test_applyto_pattern_2(): MapFusion.apply_to( sdfg, verify=True, - map_exit_1=map_exit_1, - intermediate_access_node=tmp, - map_entry_2=map_entry_2 + first_map_exit=map_exit_1, + array=tmp, + second_map_entry=map_entry_2 ) From a740d16b63ea19de592d9088e437545b4bffa6c3 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 16 Dec 2024 15:43:39 +0100 Subject: [PATCH 104/115] Updated the description of the transformation. --- dace/transformation/dataflow/map_fusion.py | 43 ++++++++++++++++++---- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index a0e479af6a..a077246574 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -12,13 +12,39 @@ @properties.make_properties class MapFusion(transformation.SingleStateTransformation): - """Fuse two serial maps together. - - The transformation combines two maps into one that are connected through some - access nodes. Conceptually this transformation removes the exit of the first - or upper map and the entry of the lower or second map and then rewrites the - connections appropriately. Depending on the situation the transformation will - either fully remove or make the intermediate a new output of the second map. + """Implements the MapFusion transformation. + + + From a high level perspective it will remove the MapExit node for the first and the MapEntry node of + the second Map. Then it will rewire and modify the Memlets to bypass the intermediate node and instead + go through a new intermediate node. This new intermediate node is much smaller because it has no longer + to absorb the whole output of the first map, but only the data that is produced by a single iteration + of the first map. It is important to note that it is not always possible to fully remove the intermediate + node, for example it is used somewhere else, see `is_shared_data()`. Thus by merging the two Maps together + the transformation will reduce the memory footprint because the intermediate nodes can be removed. + An example would be the following: + ```python + for i in range(N): + T[i] = foo(A[i]) + for j in range(N): + B[j] = bar(T[i]) + ``` + which would be translated into: + ```python + for i in range(N): + temp: scalar = foo(A[i]) + B[i] = bar(temp) + ``` + + The checks that two Maps can be fused are quite involved, however, they essentially check: + * If the two Maps cover the same iteration space, essentially have the same start, stop and + iteration , see `find_parameter_remapping()`. + * Furthermore, they verify if the new fused Map did not introduce read write conflict, + essentially it tests if the data is pointwise, i.e. what is read is also written, + see `has_read_write_dependency()`. + * Then it will examine the intermediate data. This will essentially test if the data that + is needed by a single iteration of the second Map is produced by a single iteration of + the first Map, see `partition_first_outputs()`. By default `strict_dataflow` is enabled. In this mode the transformation is more conservative. The main difference is, that it will not adjust the @@ -1015,7 +1041,8 @@ def can_topologically_be_fused( if scope[first_map_entry] is not None: return None - # We will now check if there exists a remapping that of the map parameter + # We will now check if we can rename the Map parameter of the second Map such that they + # match the one of the first Map. param_repl = self.find_parameter_remapping(first_map=first_map_entry.map, second_map=second_map_entry.map) return param_repl From d07e2c56840b02628aaccba048ec1f1e6add3efd Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 16 Dec 2024 15:44:13 +0100 Subject: [PATCH 105/115] Added a new test. --- tests/transformations/mapfusion_test.py | 65 +++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index dd6aa077e8..6a1dc880a7 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -1637,7 +1637,72 @@ def ref(A, B, C): assert np.allclose(arg_ref, arg_res) +def test_fusion_different_global_accesses(): + + def ref(A, B): + T = np.zeros_like(A) + for i in range(10): + T[i] = A[i] - B[i + 1] + for i in range(10): + A[i] = np.sin(T[i]) + B[i + 1] = np.cos(T[i]) + + sdfg = dace.SDFG("fusion_different_global_accesses_sdfg") + state = sdfg.add_state(is_start_block=True) + for name in "ABT": + sdfg.add_array( + name, + shape=(11,), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["T"].transient = True + T = state.add_access("T") + + state.add_mapped_tasklet( + "first_comp", + map_ranges={"__i0": "0:10"}, + inputs={ + "__in1": dace.Memlet("A[__i0]"), + "__in2": dace.Memlet("B[__i0 + 1]") + }, + code="__out = __in1 - __in2", + outputs={"__out": dace.Memlet("T[__i0]")}, + output_nodes={T}, + external_edges=True, + ) + state.add_mapped_tasklet( + "second_comp", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("T[__i0]")}, + code="__out1 = math.sin(__in1)\n__out2 = math.cos(__in1)", + outputs={ + "__out1": dace.Memlet("A[__i0]"), + "__out2": dace.Memlet("B[__i0 + 1]"), + }, + input_nodes={T}, + external_edges=True, + ) + sdfg.validate() + + apply_fusion(sdfg, removed_maps=1, final_maps=1) + + args_ref = { + 'A': np.array(np.random.rand(11), dtype=np.float64, copy=True), + 'B': np.array(np.random.rand(11), dtype=np.float64, copy=True), + } + args_res = copy.deepcopy(args_ref) + + ref(**args_ref) + sdfg(**args_res) + for arg in args_ref.keys(): + arg_ref = args_ref[arg] + arg_res = args_res[arg] + assert np.allclose(arg_ref, arg_res) + + if __name__ == '__main__': + test_fusion_different_global_accesses() test_fusion_multiple_consumers() test_fusion_intermediate_different_access() test_fusion_intermediate_different_access_mod_shape() From 5c354c6a3b3f1c998558a026e006285cb2538b36 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 17 Dec 2024 08:12:52 +0100 Subject: [PATCH 106/115] Fixed an iteration bug. Using `nodes()` on an SDFG will only give us the control flow regions, but using `state` will give us also the nested states. I looked through my code and this seems to be the only places where they appear. This fixes the correlaton test, but the heat test still fails. --- dace/transformation/dataflow/map_fusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index a077246574..f107bb9641 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -1342,7 +1342,7 @@ def _compute_shared_data_in( # - The access node has output degree larger than 1 (input degrees larger # than one, will always be partitioned as shared anyway). prevously_seen_data: Set[str] = set() - for state in sdfg.nodes(): + for state in sdfg.states(): for access_node in state.data_nodes(): if access_node.data in shared_data: From abf739cc5e16fc70993cfcdb565a934c19c383ff Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 17 Dec 2024 08:42:44 +0100 Subject: [PATCH 107/115] Fixed the problem in the heat test. The issue was similar as before. When I computed the name of the intermediate transient then I used `sdfg.node_id(state)` to get the state ID. However, now if the state is part of these recursive control flow regions then this may not work, because the state is not a direct node of the SDFG. However, if I use `self.state_id` then it works, this is what the old MapFusion was doing. --- dace/transformation/dataflow/map_fusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index f107bb9641..2f5b0cead7 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -680,7 +680,7 @@ def handle_intermediate_set( # It will only have the shape `new_inter_shape` which is basically its # output within one Map iteration. # NOTE: The insertion process might generate a new name. - new_inter_name: str = f"__s{sdfg.node_id(state)}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" + new_inter_name: str = f"__s{self.state_id}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" # Now generate the intermediate data container. if len(new_inter_shape) == 0: From 2b17111b5d802a884b65c3c7ab1ea5bcde050e6e Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 17 Dec 2024 09:16:00 +0100 Subject: [PATCH 108/115] Added a new test. This tests dynamic Memlets inside producers; the original transformation fails on it. --- tests/transformations/mapfusion_test.py | 75 ++++++++++++++++++++++++- 1 file changed, 72 insertions(+), 3 deletions(-) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 6a1dc880a7..141d3ecde0 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -26,15 +26,18 @@ def apply_fusion( sdfg: SDFG, removed_maps: Union[int, None] = None, final_maps: Union[int, None] = None, + unspecific: bool = False, ) -> SDFG: """Applies the Map fusion transformation. The function checks that the number of maps has been reduced, it is also possible to specify the number of removed maps. It is also possible to specify the final number of maps. + If `unspecific` is set to `True` then the function will just apply the + transformation and not check if maps were removed at all. """ - num_maps_before = count_node(sdfg, nodes.MapEntry) org_sdfg = copy.deepcopy(sdfg) + num_maps_before = None if unspecific else count_node(sdfg, nodes.MapEntry) with dace.config.temporary_config(): dace.Config.set("optimizer", "match_exception", value=True) sdfg.apply_transformations_repeated( @@ -42,8 +45,11 @@ def apply_fusion( validate=True, validate_all=True ) - num_maps_after = count_node(sdfg, nodes.MapEntry) + if unspecific: + return sdfg + + num_maps_after = count_node(sdfg, nodes.MapEntry) has_processed = False if removed_maps is not None: has_processed = True @@ -629,7 +635,7 @@ def test_interstate_fusion(): ref_C = A + 30 ref_D = A + 26 - assert sdfg.apply_transformations_repeated(MapFusion, validate=True, validate_all=True) == 1 + apply_fusion(sdfg, removed_maps=1) assert sdfg.number_of_nodes() == 2 assert len([node for node in state1.data_nodes() if node.data == "B"]) == 1 @@ -1701,7 +1707,70 @@ def ref(A, B): assert np.allclose(arg_ref, arg_res) +def test_fusion_dynamic_producer(): + + def ref(A, B): + for i in range(10): + if B[i] < 0.5: + A[i] = 0.0 + for i in range(10): + B[i] = np.sin(A[i]) + + sdfg = dace.SDFG("fusion_dynamic_producer_sdfg") + state = sdfg.add_state(is_start_block=True) + for name in "AB": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + B_top, B_bottom, A = (state.add_access(name) for name in "BBA") + + state.add_mapped_tasklet( + "comp1", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("B[__i0]")}, + code="if __in1 < 0.5:\n\t__out = 0.0", + outputs={"__out": dace.Memlet("A[__i0]", dynamic=True)}, + input_nodes={B_top}, + output_nodes={A}, + external_edges=True, + ) + state.add_mapped_tasklet( + "comp2", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("A[__i0]")}, + code="__out = math.sin(__in1)", + outputs={"__out": dace.Memlet("B[__i0]")}, + input_nodes={A}, + output_nodes={B_bottom}, + external_edges=True, + ) + sdfg.validate() + + # In case dynamic Memlets should be handled, we specify `unspecific`, i.e. + # only validation tests are done. However, we run a verification step to see + # if the transformation did the right thing. + apply_fusion(sdfg, unspecific=True) + + args_ref = { + 'A': np.array(np.random.rand(11), dtype=np.float64, copy=True), + 'B': np.array(np.random.rand(11), dtype=np.float64, copy=True), + } + args_res = copy.deepcopy(args_ref) + + ref(**args_ref) + csdfg = sdfg.compile() + csdfg(**args_res) + for arg in args_ref.keys(): + arg_ref = args_ref[arg] + arg_res = args_res[arg] + assert np.allclose(arg_ref, arg_res) + + if __name__ == '__main__': + test_fusion_dynamic_producer() test_fusion_different_global_accesses() test_fusion_multiple_consumers() test_fusion_intermediate_different_access() From 3ab46d86d76b75f6bca1a81cf8fd6a82854248c8 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 17 Dec 2024 10:47:54 +0100 Subject: [PATCH 109/115] Added a flag to MapFusion that allows to consider everything as shared. --- dace/transformation/dataflow/map_fusion.py | 29 +++++++++++++++++----- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 2f5b0cead7..42b4917505 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -56,13 +56,18 @@ class MapFusion(transformation.SingleStateTransformation): :param only_inner_maps: Only match Maps that are internal, i.e. inside another Map. :param only_toplevel_maps: Only consider Maps that are at the top. :param strict_dataflow: Which dataflow mode should be used, see above. + :param assume_always_shared: Assume that all intermediates are shared. :note: This transformation modifies more nodes than it matches. - :note: After the transformation has been applied simplify should be run to remove - some dead data flow, that was introduced to ensure validity. - :note: A `MapFusion` obejct can be initialized and be reused. However, - after new access nodes are added to any state, it is no longer valid - to use the object. + :note: An instance of MapFusion can be reused multiple times, with one exception. + Because the test if an intermediate can be removed or not is very expensive, + the transformation computes this information once in the beginning and then + caches it. However, the transformation lacks the means to detect if this data + has become out of data. Thus if new AccessNodes are added the cache is outdated + and the transformation should no longer be used. + :note: If `assume_always_shared` is `True` then the transformation will assume that + all intermediates are shared. This avoids the problems mentioned above with + the cache at the expense of the creation of dead dataflow. """ # Pattern Nodes @@ -87,6 +92,12 @@ class MapFusion(transformation.SingleStateTransformation): default=True, desc="If `True` then the transformation will ensure a more stricter data flow.", ) + assume_always_shared = properties.Property( + dtype=bool, + default=False, + desc="If `True` then all intermediates will be classified as shared.", + ) + # Maps SDFGs to the set of data that can not be removed, # because they transmit data _between states_, such data will be made 'shared'. # This variable acts as a cache, and is managed by 'is_shared_data()'. @@ -1291,7 +1302,7 @@ def is_shared_data( data: nodes.AccessNode, sdfg: dace.SDFG, ) -> bool: - """Tests if `data` is shared data, an can not be removed. + """Tests if `data` is shared data, an can not be removed from the SDFG. Interstate data is used to transmit data, this includes: * The data is referred in multiple states. @@ -1309,7 +1320,13 @@ def is_shared_data( :note: The function computes the this set once for every SDFG and then caches it. There is no mechanism to detect if the cache must be evicted. However, as long as no additional data is added to the SDFG, there is no problem. + :note: If `assume_always_shared` was set, then the function will always return `True`. """ + # This is the only point where we check for `assume_always_shared`. + if self.assume_always_shared: + return True + + # Check if the SDFG is known, if not scan it and compute the set. if sdfg not in self._shared_data: self._compute_shared_data_in(sdfg) return data.data in self._shared_data[sdfg] From b1fc9d1f3dd8d64f7c069d4bea1d9e9ffe812561 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 17 Dec 2024 12:05:01 +0100 Subject: [PATCH 110/115] Updated how the memlet adjustment works, this should be a bit more liberal. --- dace/transformation/dataflow/map_fusion.py | 46 ++++++++++++++-------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 42b4917505..f5ada6ae0b 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -754,11 +754,15 @@ def handle_intermediate_set( for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children(include_self=False): producer_edge = producer_tree.edge - # Associate the (already existing) Memlet with the new data. - # TODO(phimuell): Improve the code below to remove the check. - assert producer_edge.data.data == inter_name - producer_edge.data.data = new_inter_name - + # In order to preserve the intrinsic direction of Memlets we only have to change + # the `.data` attribute of the producer Memlet if it refers to the old intermediate. + # If it refers to something different we keep it. Note that this case can only + # occur if the producer is an AccessNode. + if producer_edge.data.data == inter_name: + producer_edge.data.data = new_inter_name + + # Regardless of the intrinsic direction of the Memlet, the subset we care about + # is always `dst_subset`. if is_scalar: producer_edge.data.dst_subset = "0" elif producer_edge.data.dst_subset is not None: @@ -792,9 +796,6 @@ def handle_intermediate_set( out_conn_name = "OUT_" + in_conn_name[3:] for inner_edge in state.out_edges_by_connector(second_map_entry, out_conn_name): - # TODO(phimuell): Lift this restriction - assert inner_edge.data.data == inter_name # DIRECTION!! - # As for the producer side, we now read from a smaller array, # So we must offset them, we use the original edge for this. assert inner_edge.data.src_subset is not None @@ -805,11 +806,17 @@ def handle_intermediate_set( producer_offset=producer_offset, ) - # Now we create a new connection that instead reads from the new - # intermediate, instead of the old one. For this we use the - # old Memlet as template. However it is not fully initialized. + # Now create the memlet for the new consumer. To make sure that we get all attributes + # of the Memlet we make a deep copy of it. There is a tricky part here, we have to + # access `src_subset` however, this is only correctly set once it is put inside the + # SDFG. Furthermore, we have to make sure that the Memlet does not change its direction. + # i.e. that the association of `subset` and `other_subset` does not change. For this + # reason we only modify `.data` attribute of the Memlet if its name refers to the old + # intermediate. Furthermore, to play it safe, we only access the subset, `src_subset` + # after we have inserted it to the SDFG. new_inner_memlet = copy.deepcopy(inner_edge.data) - new_inner_memlet.data = new_inter_name + if inner_edge.data.data == inter_name: + new_inner_memlet.data = new_inter_name # Now we replace the edge from the SDFG. state.remove_edge(inner_edge) @@ -826,19 +833,26 @@ def handle_intermediate_set( if is_scalar: new_inner_memlet.subset = "0" elif new_inner_memlet.src_subset is not None: + # TODO(phimuell): Figuring out if `src_subset` is None is an error. new_inner_memlet.src_subset.offset(consumer_offset, negative=True) new_inner_memlet.src_subset.pop(squeezed_dims) # Now we have to make sure that all consumers are properly updated. for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(include_self=False): - # TODO(phimuell): Lift this restriction - assert consumer_tree.edge.data.data == inter_name - consumer_edge = consumer_tree.edge - consumer_edge.data.data = new_inter_name + + # We only modify the data if the Memlet refers to the old intermediate data. + # We can not do this unconditionally, because it might change the intrinsic + # direction of a Memlet and then `src_subset` would at the next `try_initialize` + # be wrong. Note that this case only occurs if the destination is an AccessNode. + if consumer_edge.data.data == inter_name: + consumer_edge.data.data = new_inter_name + + # Now we have to adapt the subsets. if is_scalar: consumer_edge.data.src_subset = "0" elif consumer_edge.data.src_subset is not None: + # TODO(phimuell): Figuring out if `src_subset` is None is an error. consumer_edge.data.src_subset.offset(consumer_offset, negative=True) consumer_edge.data.src_subset.pop(squeezed_dims) From fdc6424c091520689f4d807d6eee01b05a954374 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 17 Dec 2024 12:05:52 +0100 Subject: [PATCH 111/115] Added a new test to check the memlet update. --- tests/transformations/mapfusion_test.py | 147 +++++++++++++++++++++++- 1 file changed, 142 insertions(+), 5 deletions(-) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 141d3ecde0..07c068df3f 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -27,6 +27,7 @@ def apply_fusion( removed_maps: Union[int, None] = None, final_maps: Union[int, None] = None, unspecific: bool = False, + apply_once: bool = False, ) -> SDFG: """Applies the Map fusion transformation. @@ -40,11 +41,18 @@ def apply_fusion( num_maps_before = None if unspecific else count_node(sdfg, nodes.MapEntry) with dace.config.temporary_config(): dace.Config.set("optimizer", "match_exception", value=True) - sdfg.apply_transformations_repeated( - MapFusion(strict_dataflow=True), - validate=True, - validate_all=True - ) + if apply_once: + sdfg.apply_transformations( + MapFusion(strict_dataflow=True), + validate=True, + validate_all=True + ) + else: + sdfg.apply_transformations_repeated( + MapFusion(strict_dataflow=True), + validate=True, + validate_all=True + ) if unspecific: return sdfg @@ -1769,7 +1777,136 @@ def ref(A, B): assert np.allclose(arg_ref, arg_res) +def test_fusion_intrinsic_memlet_direction(): + + def ref(A, B): + T = A + 10.0 + B[:] = np.sin(T) + + sdfg = dace.SDFG("fusion_dynamic_producer_sdfg") + state = sdfg.add_state(is_start_block=True) + + for name in "ATB": + sdfg.add_array( + name, + shape=(10, 11, 12), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["T"].transient = True + + for num in "12": + sdfg.add_scalar( + "t" + num, + dtype=dace.float64, + transient=True, + ) + + A, T, B, t1, t2 = (state.add_access(name) for name in ["A", "T", "B", "t1", "t2"]) + + tsklt1, me1, mx1 = state.add_mapped_tasklet( + "comp1", + map_ranges={ + "__i1": "0:10", + "__i2": "0:11", + "__i3": "0:12", + }, + inputs={"__in1": dace.Memlet("A[__i1, __i2, __i3]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("T[__i1, __i2, __i3]")}, + input_nodes={A}, + output_nodes={T}, + external_edges=True, + ) + + tsklt2, me2, mx2 = state.add_mapped_tasklet( + "comp2", + map_ranges={ + "__i1": "0:10", + "__i2": "0:11", + "__i3": "0:12", + }, + inputs={"__in1": dace.Memlet("T[__i1, __i2, __i3]")}, + code="__out = math.sin(__in1)", + outputs={"__out": dace.Memlet("B[__i1, __i2, __i3]")}, + input_nodes={T}, + output_nodes={B}, + external_edges=True, + ) + + for me in [me1, me2]: + dace.transformation.dataflow.MapExpansion.apply_to( + sdfg, + options={"inner_schedule": dace.ScheduleType.Default}, + map_entry=me, + ) + + # Now add a transient scalar at the output of `tsklt1`. + tsklt1_oedge = next(iter(state.out_edges(tsklt1))) + me1_inner = tsklt1_oedge.dst + state.add_edge( + tsklt1, "__out", + t1, None, + dace.Memlet("t1[0]"), + ) + state.add_edge( + t1, None, + me1_inner, tsklt1_oedge.dst_conn, + dace.Memlet("t1[0] -> [__i1, __i2, __i3]"), + ) + state.remove_edge(tsklt1_oedge) + tsklt1_oedge = None + + # Now add a transient scalar in the front of `tsklt2`. + tsklt2_iedge = next(iter(state.in_edges(tsklt2))) + me2_inner = tsklt2_iedge.src + state.add_edge( + me2_inner, tsklt2_iedge.src_conn, + t2, None, + dace.Memlet("t2[0] -> [__i1, __i2, __i3]"), + ) + state.add_edge( + t2, None, + tsklt2, "__in1", + dace.Memlet("t2[0]"), + ) + state.remove_edge(tsklt2_iedge) + tsklt2_iedge = None + sdfg.validate() + + # By Specifying `apply_once` we only perform one fusion, which will eliminate `T`. + # This is not efficient, we do this to make sure that the update of the Memlets + # has worked. + apply_fusion(sdfg, apply_once=True) + + for edge in state.edges(): + # There should be no edge, that references `T`. + assert edge.data.data != "T" + + # If an edge is connected to `t2` or `t1` then its data should refer to it. + # no other Memlet shall refer to them. + for t in [t1, t2]: + if edge.src is t or edge.dst is t: + assert edge.data.data == t.data + else: + assert edge.data.data != t.data + + args_ref = { + 'A': np.array(np.random.rand(10, 11, 12), dtype=np.float64, copy=True), + 'B': np.array(np.random.rand(10, 11, 12), dtype=np.float64, copy=True), + } + args_res = copy.deepcopy(args_ref) + + ref(**args_ref) + sdfg(**args_res) + for arg in args_ref.keys(): + arg_ref = args_ref[arg] + arg_res = args_res[arg] + assert np.allclose(arg_ref, arg_res) + + if __name__ == '__main__': + test_fusion_intrinsic_memlet_direction() test_fusion_dynamic_producer() test_fusion_different_global_accesses() test_fusion_multiple_consumers() From bcaed235881b78a3136bc008cf040f320e4d1a52 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 17 Jan 2025 08:35:52 +0100 Subject: [PATCH 112/115] Centralized the map fusion call in the testing. --- tests/transformations/mapfusion_test.py | 91 ++++++++++--------------- 1 file changed, 35 insertions(+), 56 deletions(-) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 07c068df3f..435b8a572c 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -28,6 +28,7 @@ def apply_fusion( final_maps: Union[int, None] = None, unspecific: bool = False, apply_once: bool = False, + strict_dataflow: bool = True, ) -> SDFG: """Applies the Map fusion transformation. @@ -36,23 +37,31 @@ def apply_fusion( number of maps. If `unspecific` is set to `True` then the function will just apply the transformation and not check if maps were removed at all. + If `strict_dataflow` is set to `True`, the default, then the function will perform + the fusion in strict dataflow mode. """ org_sdfg = copy.deepcopy(sdfg) num_maps_before = None if unspecific else count_node(sdfg, nodes.MapEntry) - with dace.config.temporary_config(): - dace.Config.set("optimizer", "match_exception", value=True) - if apply_once: - sdfg.apply_transformations( - MapFusion(strict_dataflow=True), - validate=True, - validate_all=True - ) - else: - sdfg.apply_transformations_repeated( - MapFusion(strict_dataflow=True), - validate=True, - validate_all=True - ) + + try: + with dace.config.temporary_config(): + dace.Config.set("optimizer", "match_exception", value=True) + if apply_once: + sdfg.apply_transformations( + MapFusion(strict_dataflow=strict_dataflow), + validate=True, + validate_all=True + ) + else: + sdfg.apply_transformations_repeated( + MapFusion(strict_dataflow=strict_dataflow), + validate=True, + validate_all=True + ) + except: + org_sdfg.view() + sdfg.view() + raise if unspecific: return sdfg @@ -927,12 +936,7 @@ def test_fusion_strict_dataflow_pointwise(): sdfg, state = _make_strict_dataflow_sdfg_pointwise(input_data="A") # However, if strict dataflow is disabled, then it will be able to fuse. - count = sdfg.apply_transformations_repeated( - MapFusion(strict_dataflow=False), - validate=True, - validate_all=True, - ) - assert count == 1 + apply_fusion(sdfg, removed_maps=1, strict_dataflow=False) def test_fusion_strict_dataflow_not_pointwise(): @@ -944,12 +948,7 @@ def test_fusion_strict_dataflow_not_pointwise(): # Because the dependency is not pointwise even disabling strict dataflow # will not make it work. - count = sdfg.apply_transformations_repeated( - MapFusion(strict_dataflow=False), - validate=True, - validate_all=True, - ) - assert count == 0 + apply_fusion(sdfg, removed_maps=0, strict_dataflow=False) def test_fusion_dataflow_intermediate(): @@ -958,12 +957,11 @@ def test_fusion_dataflow_intermediate(): intermediate_data="O", output_data="O", ) - count = sdfg.apply_transformations_repeated( - MapFusion(strict_dataflow=True), - validate=True, - validate_all=True, - ) - assert count == 0 + apply_fusion(sdfg, removed_maps=0, strict_dataflow=True) + + # Because the intermediate is also output of the second map it is not possible + # to fuse even without strict dataflow mode. + apply_fusion(sdfg, removed_maps=0, strict_dataflow=False) def test_fusion_dataflow_intermediate_2(): @@ -973,12 +971,7 @@ def test_fusion_dataflow_intermediate_2(): intermediate_data="A", output_data="O", ) - count = sdfg.apply_transformations_repeated( - MapFusion(strict_dataflow=True), - validate=True, - validate_all=True, - ) - assert count == 1 + apply_fusion(sdfg, removed_maps=1, strict_dataflow=True) map_exit = next(iter(node for node in state.nodes() if isinstance(node, nodes.MapExit))) assert state.out_degree(map_exit) == 2 assert {"A", "O"} == {edge.dst.data for edge in state.out_edges(map_exit) if isinstance(edge.dst, nodes.AccessNode)} @@ -1010,20 +1003,11 @@ def test_fusion_dataflow_intermediate_downstream(): ) sdfg.validate() - count = sdfg.apply_transformations_repeated( - MapFusion(strict_dataflow=True), - validate=True, - validate_all=True, - ) - assert count == 0 + apply_fusion(sdfg, removed_maps=0, strict_dataflow=True) + sdfg.view() # However without strict dataflow, the merge is possible. - count = sdfg.apply_transformations_repeated( - MapFusion(strict_dataflow=False), - validate=True, - validate_all=True, - ) - assert count == 1 + apply_fusion(sdfg, removed_maps=1, strict_dataflow=False) assert state.in_degree(output_1) == 1 assert state.out_degree(output_1) == 1 assert all(isinstance(edge.src, nodes.MapExit) for edge in state.in_edges(output_1)) @@ -1115,12 +1099,7 @@ def test_fusion_non_strict_dataflow_implicit_dependency(): ) sdfg.validate() - count = sdfg.apply_transformations_repeated( - MapFusion(strict_dataflow=False), - validate=True, - validate_all=True, - ) - assert count == 0 + apply_fusion(sdfg, removed_maps=0, strict_dataflow=False) def _make_inner_conflict_shared_scalar( From 243611d6b60b9f2e1aa9d9dd2375a48337082df4 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 17 Jan 2025 08:57:13 +0100 Subject: [PATCH 113/115] Added a test that ensures that no cycles would be created. --- tests/transformations/mapfusion_test.py | 77 +++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 435b8a572c..ed7e5666ca 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -1884,6 +1884,82 @@ def ref(A, B): assert np.allclose(arg_ref, arg_res) +def _make_possible_cycle_if_fuesed_sdfg() -> Tuple[dace.SDFG, nodes.MapExit, nodes.AccessNode, nodes.MapEntry]: + """Generate an SDFG that if two maps would be fused a cycle would be created. + + Essentially tests if the MapFusion detects this special case. + """ + sdfg = dace.SDFG("possible_cycle_if_fuesed_sdfg") + state = sdfg.add_state(is_start_block=True) + + names = ["A", "B", "T", "U", "V"] + for name in names: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=True, + ) + sdfg.arrays["A"].transient = False + sdfg.arrays["B"].transient = False + + A, B, T, U, V = (state.add_access(name) for name in names) + + _, _, first_map_exit = state.add_mapped_tasklet( + "map1", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("A[__i]")}, + code="__out1 = __in + 10\n__out2 = __in - 10", + outputs={ + "__out1": dace.Memlet("T[__i]"), + "__out2": dace.Memlet("U[__i]"), + }, + input_nodes={A}, + output_nodes={T, U}, + external_edges=True, + ) + + state.add_mapped_tasklet( + "map2", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("U[__i]")}, + code="__out = math.sin(__in)", + outputs={"__out": dace.Memlet("V[__i]")}, + input_nodes={U}, + output_nodes={V}, + external_edges=True, + ) + + _, second_map_entry, _ = state.add_mapped_tasklet( + "map3", + map_ranges={"__i": "0:10"}, + inputs={ + "__in1": dace.Memlet("T[__i]"), + "__in2": dace.Memlet("V[__i]"), + }, + code="__out = __in1 + __in2", + outputs={"__out": dace.Memlet("B[__i]")}, + input_nodes={T, V}, + output_nodes={B}, + external_edges=True, + ) + sdfg.validate() + + return sdfg, first_map_exit, T, second_map_entry + + +def test_possible_cycle_if_fuesed_sdfg(): + sdfg, first_map_exit, array, second_map_entry = _make_possible_cycle_if_fuesed_sdfg() + + would_transformation_apply = MapFusion.can_be_applied_to( + sdfg, + first_map_exit=first_map_exit, + array=array, + second_map_entry=second_map_entry, + ) + assert not would_transformation_apply + + if __name__ == '__main__': test_fusion_intrinsic_memlet_direction() test_fusion_dynamic_producer() @@ -1917,3 +1993,4 @@ def ref(A, B): test_different_offsets() test_inner_map_dependency() test_inner_map_dependency_resolved() + test_possible_cycle_if_fuesed_sdfg() From c53f939682250995fbbe45c033185c2d1fccc0dc Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 17 Jan 2025 13:13:39 +0100 Subject: [PATCH 114/115] Found a case that was not handled. The transformation checks if the first map satisifes the data dependencies of the second map. For this is looks at the writes and reads of the intermediate. It also checks if, a data container is used as input of the first and as output of the second map, if the access is pointwise and can be fused. Furthermore, it was allowed that the intermediate is also used as input to the first map. However, in that particular case, it was not checked if the the reads and writes of the first map alone to the intermediate are valid. I.e. it could read read `A[i]` but write `A[i+1]` which would cause problems (note that this usage is botherline legal anyway. This commit adds a check to make sure that this is not the case by enforcing if a data container is used as input and output of the first map and also as intermediate node then its read must be pointwise. Note that if it is not an intermediate node, i.e. not also read by the second map, then this rule does not apply. NOTE: It is forbidden that the intermediate is used as intermediate and output of the second map. --- dace/transformation/dataflow/map_fusion.py | 63 +++++++++++++++++++--- 1 file changed, 57 insertions(+), 6 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index f5ada6ae0b..dccebe9727 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -1144,14 +1144,21 @@ def has_read_write_dependency( ) -> bool: """Test if there is a read write dependency between the two maps to be fused. - The function checks two different things. + The function checks three different things. * The function will make sure that there is no read write dependency between the input and output of the fused maps. For that it will inspect the - respective subsets. + respective subsets of the inputs of the MapEntry of the first and the + outputs of the MapExit node of the second map. * The second part partially checks the intermediate nodes, it mostly ensures - that there are not views and that they are not used as inputs or outputs - at the same time. However, the function will not check for read write - conflicts in this set, this is done in the partition function. + that there are not views and that they are not used as output of the + combined map. Note that it is allowed that an intermediate node is also + an input to the first map. + * In case an intermediate node, is also used as input node of the first map, + it is forbidden that the data is used as output of the second map, the + function will do additional checks. This is needed as the partition function + only checks the data consumption of the second map can be satisfied by the + data production of the first map, it ignores any potential reads made by + the first map's MapEntry. :return: `True` if there is a conflict between the maps that can not be handled. If there is no conflict or if the conflict can be handled `False` is returned. @@ -1195,6 +1202,9 @@ def has_read_write_dependency( real_read_map_1, real_write_map_1, real_read_map_2, real_write_map_2 = resolved_sets # We do not allow that the first and second map each write to the same data. + # This essentially ensures that an intermediate can not be used as output of + # the second map at the same time. It is actually stronger as it does not + # take their role into account. if not real_write_map_1.isdisjoint(real_write_map_2): return True @@ -1204,7 +1214,8 @@ def has_read_write_dependency( exchange_names: Set[str] = set(write_map_1.keys()).intersection(read_map_2.keys()) exchange_nodes: Set[nodes.AccessNode] = set(write_map_1.values()).intersection(read_map_2.values()) - # If the number are different then a data is accessed through multiple nodes. + # If the number are different then a data is accessed through different + # AccessNodes. We could analyse this, but we will consider this as a data race. if len(exchange_names) != len(exchange_nodes): return True assert all(exchange_node.data in exchange_names for exchange_node in exchange_nodes) @@ -1234,6 +1245,46 @@ def has_read_write_dependency( if not fused_inout_data_names.isdisjoint(exchange_names): return True + # While it is forbidden that a data container, used as intermediate, is also + # used as output of the second map. It is allowed that the data container + # is used as intermediate and as input of the first map. The partition only + # checks that the data dependencies are mean, i.e. what is read by the second + # map is also computed (written to the intermediate) it does not take into + # account the first map's read to the data container. + # To make an example: The partition function will make sure that if the + # second map reads index `i` from the intermediate that the first map writes + # to that index. But it will not care if the first map reads (through its + # MapEntry) index `i + 1`. In order to be valid me must ensure that the first + # map's reads and writes to the intermediate are pointwise. + # Note that we only have to make this check if it is also an intermediate node. + # Because if it is not read by the second map it is not a problem as the node + # will end up as an pure output node anyway. + read_write_map_1 = set(read_map_1.keys()).intersection(write_map_1.keys()) + datas_to_inspect = read_write_map_1.intersection(exchange_names) + for data_to_inspect in datas_to_inspect: + + # Now get all subsets of the data container that the first map reads + # from or writes to and check if they are pointwise. + all_subsets: List[subsets.Subset] = [] + all_subsets.extend( + self.find_subsets( + node=read_map_1[data_to_inspect], + scope_node=first_map_entry, + state=state, + sdfg=sdfg, + param_repl=None, + )) + all_subsets.extend( + self.find_subsets( + node=write_map_1[data_to_inspect], + scope_node=first_map_exit, + state=state, + sdfg=sdfg, + param_repl=None, + )) + if not self.test_if_subsets_are_point_wise(all_subsets): + return True + # If there is no intersection between the input and output data, then we can # we have nothing to check. if len(fused_inout_data_names) == 0: From a3842f97ebba4a0bec8c2783c02c129acc244524 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 17 Jan 2025 11:35:07 +0100 Subject: [PATCH 115/115] Added more tests to the map fusion and refined some others. --- tests/transformations/mapfusion_test.py | 71 +++++++++++++++++++++---- 1 file changed, 60 insertions(+), 11 deletions(-) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index ed7e5666ca..1e7678c3aa 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -965,7 +965,11 @@ def test_fusion_dataflow_intermediate(): def test_fusion_dataflow_intermediate_2(): - # Because `A` is not also output transformation applies. + # The transformation applies for two reasons, first reading and writing `A` + # is pointwise. Furthermore, there is no further access to `A` after the + # intermediate node. Note that if the second map would also have an output + # that refers to `A` then the transformation would not apply regardless + # of the strict dataflow mode. sdfg, state = _make_strict_dataflow_sdfg_pointwise( input_data="A", intermediate_data="A", @@ -977,6 +981,21 @@ def test_fusion_dataflow_intermediate_2(): assert {"A", "O"} == {edge.dst.data for edge in state.out_edges(map_exit) if isinstance(edge.dst, nodes.AccessNode)} +def test_fusion_dataflow_intermediate_3(): + # This is exactly the same situation as in `test_fusion_dataflow_intermediate_2()` + # with the exception that now the access to `A` is no longer pointwise, thus + # the transformation does not apply. Note that this SDFG is wrong, it is only + # here to show that the case is detected. + sdfg, state = _make_strict_dataflow_sdfg_pointwise( + input_data="A", + intermediate_data="A", + output_data="O", + input_read="9 - __i0", + output_write="__i0", + ) + apply_fusion(sdfg, removed_maps=0, strict_dataflow=True) + + def test_fusion_dataflow_intermediate_downstream(): # Because the intermediate `T` is used downstream again, # the transformation can not apply. @@ -1001,10 +1020,18 @@ def test_fusion_dataflow_intermediate_downstream(): input_nodes={output_1}, external_edges=True, ) + + # Make another state where `T` is written back, such that it is not dead data flow. + state2 = sdfg.add_state_after(state) + sdfg.add_datadesc("output_2", sdfg.arrays["output_1"].clone()) + state2.add_nedge( + state2.add_access("T"), + state2.add_access("output_2"), + sdfg.make_array_memlet("T"), + ) sdfg.validate() apply_fusion(sdfg, removed_maps=0, strict_dataflow=True) - sdfg.view() # However without strict dataflow, the merge is possible. apply_fusion(sdfg, removed_maps=1, strict_dataflow=False) @@ -1238,8 +1265,10 @@ def test_inner_map_dependency_resolved(): apply_fusion(sdfg, removed_maps=1, final_maps=1) -def _impl_fusion_intermediate_different_access(modified_shape: bool): - +def _impl_fusion_intermediate_different_access( + modified_shape: bool, + traditional_memlet_direction: bool +): def ref(A, B): T = np.zeros((A.shape[0] + 1, 2)) for i in range(A.shape[0]): @@ -1338,14 +1367,26 @@ def ref(A, B): ) ) + temp_subset = ( + "0, 0:2" + if modified_shape + else "0:2" + ) + T_subset = "__i0 + 1, 0:2" + + if traditional_memlet_direction: + mem_data = "T" + mem_subset = T_subset + mem_other_subset = temp_subset + else: + mem_data = "temp" + mem_subset = temp_subset + mem_other_subset = T_subset + state.add_edge( temp, None, mx1, "IN_temp", - dace.Memlet( - "temp[0, 0:2] -> [__i0 + 1, 0:2]" - if modified_shape - else "temp[0:2] -> [__i0 + 1, 0:2]" - ) + dace.Memlet(f"{mem_data}[{mem_subset}] -> [{mem_other_subset}]") ) state.add_edge( mx1, "OUT_temp", @@ -1384,11 +1425,19 @@ def ref(A, B): def test_fusion_intermediate_different_access(): - _impl_fusion_intermediate_different_access(modified_shape=False) + _impl_fusion_intermediate_different_access(modified_shape=False, traditional_memlet_direction=False) + + +def test_fusion_intermediate_different_access_2(): + _impl_fusion_intermediate_different_access(modified_shape=False, traditional_memlet_direction=True) def test_fusion_intermediate_different_access_mod_shape(): - _impl_fusion_intermediate_different_access(modified_shape=True) + _impl_fusion_intermediate_different_access(modified_shape=True, traditional_memlet_direction=False) + + +def test_fusion_intermediate_different_access_mod_shape_2(): + _impl_fusion_intermediate_different_access(modified_shape=True, traditional_memlet_direction=True) @pytest.mark.skip(reason="This feature is not yet fully supported.")