diff --git a/dace/transformation/dataflow/redundant_array.py b/dace/transformation/dataflow/redundant_array.py index 7b241ff9cd..6c6768f254 100644 --- a/dace/transformation/dataflow/redundant_array.py +++ b/dace/transformation/dataflow/redundant_array.py @@ -7,9 +7,10 @@ from typing import Dict, List, Optional, Tuple import networkx as nx +from dace.subsets import Range, Indices, SubrangeMapper from networkx.exception import NetworkXError, NodeNotFound -from dace import data, dtypes +from dace import data, dtypes, Memlet from dace import memlet as mm from dace import subsets, symbolic from dace.config import Config @@ -24,7 +25,7 @@ def _validate_subsets(edge: graph.MultiConnectorEdge, arrays: Dict[str, data.Data], src_name: str = None, - dst_name: str = None) -> Tuple[subsets.Subset]: + dst_name: str = None) -> Tuple[subsets.Subset, subsets.Subset]: """ Extracts and validates src and dst subsets from the edge. """ # Find src and dst names @@ -42,6 +43,7 @@ def _validate_subsets(edge: graph.MultiConnectorEdge, if not src_subset and not dst_subset: # NOTE: This should never happen raise NotImplementedError('Neither source nor destination subsets are defined') + return src_subset, dst_subset # NOTE: If any of the subsets is None, it means that we proceed in # experimental mode. The base case here is that we just copy the other # subset. However, if we can locate the other array, we check the @@ -499,182 +501,65 @@ def _is_reshaping_memlet( return True - def apply(self, graph, sdfg): - in_array = self.in_array - out_array = self.out_array - in_desc = sdfg.arrays[in_array.data] - out_desc = sdfg.arrays[out_array.data] + def apply(self, graph: SDFGState, sdfg: SDFG): + # The pattern is A ---> B, and we want to remove A + A, B = self.in_array, self.out_array # 1. Get edge e1 and extract subsets for arrays A and B - e1 = graph.edges_between(in_array, out_array)[0] - a1_subset, b_subset = _validate_subsets(e1, sdfg.arrays) - - # View connected to a view: simple case - if (isinstance(in_desc, data.View) and isinstance(out_desc, data.View)): - simple_case = True - for e in graph.in_edges(in_array): - if e.data.dst_subset is not None and a1_subset != e.data.dst_subset: - simple_case = False - break - if simple_case: - for e in graph.in_edges(in_array): - for e2 in graph.memlet_tree(e): - if e2 is e: - continue - if e2.data.data == in_array.data: - e2.data.data = out_array.data - new_memlet = copy.deepcopy(e.data) - if new_memlet.data == in_array.data: - new_memlet.data = out_array.data - new_memlet.dst_subset = b_subset - graph.add_edge(e.src, e.src_conn, out_array, e.dst_conn, new_memlet) - graph.remove_node(in_array) - try: - if in_array.data in sdfg.arrays: - sdfg.remove_data(in_array.data) - except ValueError: # Used somewhere else - pass - return - - # Find extraneous A or B subset dimensions - a_dims_to_pop = [] - b_dims_to_pop = [] - bset = b_subset - popped = [] - if a1_subset and b_subset and a1_subset.dims() != b_subset.dims(): - a_size = a1_subset.size_exact() - b_size = b_subset.size_exact() - if a1_subset.dims() > b_subset.dims(): - a_dims_to_pop = find_dims_to_pop(a_size, b_size) - else: - b_dims_to_pop = find_dims_to_pop(b_size, a_size) - bset, popped = pop_dims(b_subset, b_dims_to_pop) - - from dace.libraries.standard import Reduce - reduction = False - for e in graph.in_edges(in_array): - if isinstance(e.src, Reduce) or (isinstance(e.src, nodes.NestedSDFG) - and len(in_desc.shape) != len(out_desc.shape)): - reduction = True - - # If: - # 1. A reduce node is involved; or - # 2. A NestedSDFG node is involved and the arrays have different dimensionality; or - # 3. The memlet does not cover the removed array; or - # 4. Dimensions are mismatching (all dimensions are popped); - # create a view. - if ( - reduction - or len(a_dims_to_pop) == len(in_desc.shape) - or any(m != a for m, a in zip(a1_subset.size(), in_desc.shape)) - ): - self._make_view(sdfg, graph, in_array, out_array, e1, b_subset, b_dims_to_pop) - return in_array - - # TODO: Fix me. - # As described in [issue 1595](https://github.com/spcl/dace/issues/1595) the - # transformation is unable to handle certain cases of reshaping Memlets - # correctly and fixing this case has proven rather difficult. In a first - # attempt the case of reshaping Memlets was forbidden (in the - # `can_be_applied()` method), however, this caused other (useful) cases to - # fail. For that reason such Memlets are transformed to Views. - # This is a fix and it should be addressed. - if self._is_reshaping_memlet(graph=graph, edge=e1): - self._make_view(sdfg, graph, in_array, out_array, e1, b_subset, b_dims_to_pop) - return in_array - - # Validate that subsets are composable. If not, make a view - try: - for e2 in graph.in_edges(in_array): - path = graph.memlet_tree(e2) - wcr = e1.data.wcr - wcr_nonatomic = e1.data.wcr_nonatomic - for e3 in path: - # 2-a. Extract subsets for array B and others - other_subset, a3_subset = _validate_subsets(e3, sdfg.arrays, dst_name=in_array.data) - # 2-b. Modify memlet to match array B. - dname = out_array.data - src_is_data = False - a3_subset.offset(a1_subset, negative=True) - - if a3_subset and a_dims_to_pop: - aset, _ = pop_dims(a3_subset, a_dims_to_pop) - else: - aset = a3_subset - - compose_and_push_back(bset, aset, b_dims_to_pop, popped) - except (ValueError, NotImplementedError): - self._make_view(sdfg, graph, in_array, out_array, e1, b_subset, b_dims_to_pop) - print(f"CREATED VIEW(2): {in_array}") - return in_array - - # 2. Iterate over the e2 edges and traverse the memlet tree - for e2 in graph.in_edges(in_array): - path = graph.memlet_tree(e2) - wcr = e1.data.wcr - wcr_nonatomic = e1.data.wcr_nonatomic - for e3 in path: - # 2-a. Extract subsets for array B and others - other_subset, a3_subset = _validate_subsets(e3, sdfg.arrays, dst_name=in_array.data) - # 2-b. Modify memlet to match array B. - dname = out_array.data - src_is_data = False - a3_subset.offset(a1_subset, negative=True) - - if a3_subset and a_dims_to_pop: - aset, _ = pop_dims(a3_subset, a_dims_to_pop) - else: - aset = a3_subset - - dst_subset = compose_and_push_back(bset, aset, b_dims_to_pop, popped) - # NOTE: This fixes the following case: - # Tasklet ----> A[subset] ----> ... -----> A - # Tasklet is not data, so it doesn't have an other subset. - if isinstance(e3.src, nodes.AccessNode): - if e3.src.data == out_array.data: - dname = e3.src.data - src_is_data = True - src_subset = other_subset - else: - src_subset = None - - subset = src_subset if src_is_data else dst_subset - other_subset = dst_subset if src_is_data else src_subset - e3.data.data = dname - e3.data.subset = subset - e3.data.other_subset = other_subset - wcr = wcr or e3.data.wcr - wcr_nonatomic = wcr_nonatomic or e3.data.wcr_nonatomic - e3.data.wcr = wcr - e3.data.wcr_nonatomic = wcr_nonatomic - - # 2-c. Remove edge and add new one - graph.remove_edge(e2) - e2.data.wcr = wcr - e2.data.wcr_nonatomic = wcr_nonatomic - graph.add_edge(e2.src, e2.src_conn, out_array, e2.dst_conn, e2.data) - - # 2-d. Fix strides in nested SDFGs - if in_desc.strides != out_desc.strides: - sources = [] - if path.downwards: - sources = [path.root().edge] - else: - sources = [e for e in path.leaves()] - for source_edge in sources: - if not isinstance(source_edge.src, nodes.NestedSDFG): - continue - conn = source_edge.src_conn - inner_desc = source_edge.src.sdfg.arrays[conn] - inner_desc.strides = out_desc.strides - - # Finally, remove in_array node - graph.remove_node(in_array) - try: - if in_array.data in sdfg.arrays: - sdfg.remove_data(in_array.data) - except ValueError: # Already in use (e.g., with Views) - pass + e_ab = graph.edges_between(A, B) + assert len(e_ab) == 1 + e_ab = e_ab[0] + print(e_ab) + a_subset, b_subset = _validate_subsets(e_ab, sdfg.arrays) + # Other cases should have been handled in `can_be_applied()`. + assert isinstance(a_subset, Range) or isinstance(a_subset, Indices) + assert isinstance(b_subset, Range) or isinstance(b_subset, Indices) + # And this should be self-evident. + assert a_subset.volume_exact() == b_subset.volume_exact() + + for ie in graph.in_edges(A): + # The pattern is now: C -(ie)-> A ---> B + path = graph.memlet_tree(ie) + for pe in path: + # The pattern is now: C -(pe)-> C1 ---> ... ---> A ---> B + print('PE:', pe) + c_subset, a0_subset = _validate_subsets(pe, sdfg.arrays, dst_name=A.data) + print('c, a0:', c_subset, a0_subset) + if a0_subset is None: + continue + + # Other cases should have been handled already in `can_be_applied()`. + assert c_subset is None or isinstance(c_subset, Range) or isinstance(c_subset, Indices) + if c_subset is not None: + assert c_subset.volume_exact() == a0_subset.volume_exact() + assert isinstance(a0_subset, Range) or isinstance(a0_subset, Indices) + print('SUBS:', a_subset, '|', a0_subset) + assert a_subset.dims() == a0_subset.dims() + # assert a_subset.covers_precise(a0_subset) + # assert all(b0 >= b and (b0 - b) % s == 0 and s0 % s == 0 + # for (b, e, s), (b0, e0, s0) in zip(a_subset.ndrange(), a0_subset.ndrange())) + + # Find out where `a0_subset` maps to, given that `a_subset` precisely maps to `b_subset`. + # `reshapr` describes how `a_subset` maps to `b_subset`. + reshapr = SubrangeMapper(a_subset, b_subset) + # `b0_subset` is the mapping for `a0_subset`. + b0_subset = reshapr.map(a0_subset) + print(a_subset, b_subset) + print(a0_subset, b0_subset) + assert isinstance(b0_subset, Range) or isinstance(b0_subset, Indices) + assert b0_subset.volume_exact() == a0_subset.volume_exact() + + # Now we can replace the path: C -(pe)-> C1 ---> ... ---> A ---> B + # with an equivalent path: C -(pe)-> C1 ---> ... ---> B + dst, dst_conn = (B, None) if pe.dst is A else (pe.dst, pe.dst_conn) + print('dst:', pe.src, dst, dst_conn) + print('mem:', B.data, b0_subset, c_subset) + e = graph.add_edge(pe.src, pe.src_conn, dst, dst_conn, + memlet=Memlet(data=B.data, subset=b0_subset, other_subset=c_subset)) + print('e:', e) + graph.remove_edge(pe) + graph.remove_node(A) + sdfg.remove_data(A.data) class RedundantSecondArray(pm.SingleStateTransformation):