diff --git a/dace/transformation/subgraph/stencil_tiling.py b/dace/transformation/subgraph/stencil_tiling.py index c9c3e9afd4..f6c60eea8f 100644 --- a/dace/transformation/subgraph/stencil_tiling.py +++ b/dace/transformation/subgraph/stencil_tiling.py @@ -2,34 +2,27 @@ """ This module contains classes and functions that implement the orthogonal stencil tiling transformation. """ -import math +import itertools +import warnings +from collections import defaultdict +from copy import deepcopy as dcpy import dace -from dace import dtypes, registry, symbolic +import dace.subsets as subsets +import dace.symbolic as symbolic +from dace import dtypes from dace.properties import make_properties, Property, ShapeProperty from dace.sdfg import nodes -from dace.transformation import transformation from dace.sdfg.propagation import _propagate_node - -from dace.transformation.dataflow.map_for_loop import MapToForLoop -from dace.transformation.dataflow.map_expansion import MapExpansion +from dace.transformation import transformation from dace.transformation.dataflow.map_collapse import MapCollapse +from dace.transformation.dataflow.map_expansion import MapExpansion +from dace.transformation.dataflow.map_for_loop import MapToForLoop from dace.transformation.dataflow.strip_mining import StripMining -from dace.transformation.interstate.loop_unroll import LoopUnroll from dace.transformation.interstate.loop_detection import DetectLoop -from dace.transformation.subgraph import SubgraphFusion - -from copy import deepcopy as dcpy - -import dace.subsets as subsets -import dace.symbolic as symbolic - -import itertools -import warnings - -from collections import defaultdict - +from dace.transformation.interstate.loop_unroll import LoopUnroll from dace.transformation.subgraph import helpers +from dace.transformation.subgraph import subgraph_fusion @make_properties @@ -51,7 +44,7 @@ class StencilTiling(transformation.SubgraphTransformation): prefix = Property(dtype=str, default="stencil", desc="Prefix for new inner tiled range symbols") - strides = ShapeProperty(dtype=tuple, default=(1, ), desc="Tile stride") + strides = ShapeProperty(dtype=tuple, default=(1,), desc="Tile stride") schedule = Property(dtype=dace.dtypes.ScheduleType, default=dace.dtypes.ScheduleType.Default, @@ -200,13 +193,13 @@ def can_be_applied(sdfg, subgraph) -> bool: # get intermediate_nodes, out_nodes from SubgraphFusion Transformation try: - node_config = get_adjacent_nodes(sdfg, graph, map_entries) + node_config = subgraph_fusion.get_adjacent_nodes(graph, map_entries) (_, intermediate_nodes, out_nodes) = node_config except NotImplementedError: return False # 1.4: check topological feasibility - if not check_topo_feasibility(sdfg, graph, map_entries, intermediate_nodes, out_nodes): + if not subgraph_fusion.check_topo_feasibility(graph, map_entries, intermediate_nodes, out_nodes): return False # 1.5 nodes that are both intermediate and out nodes # are not supported in StencilTiling @@ -215,8 +208,8 @@ def can_be_applied(sdfg, subgraph) -> bool: # 1.6 check that we only deal with compressible transients - subgraph_contains_data = determine_compressible_nodes(sdfg, graph, intermediate_nodes, - map_entries, map_exits) + subgraph_contains_data = subgraph_fusion.determine_compressible_nodes(sdfg, graph, intermediate_nodes, + map_entries, map_exits) if any([s == False for s in subgraph_contains_data.values()]): return False @@ -264,8 +257,8 @@ def can_be_applied(sdfg, subgraph) -> bool: for i, (p_subset, c_subset) in enumerate(zip(parent_coverage, children_coverage)): # transform into subset - p_subset = subsets.Range((p_subset, )) - c_subset = subsets.Range((c_subset, )) + p_subset = subsets.Range((p_subset,)) + c_subset = subsets.Range((c_subset,)) # get associated parameter in memlet params1 = symbolic.symlist(memlets[map_entry][1][data_name][i]).keys() @@ -292,7 +285,7 @@ def can_be_applied(sdfg, subgraph) -> bool: except KeyError: return False - #parameter mapping must be the same + # parameter mapping must be the same if param_parent_coverage != param_children_coverage: return False @@ -394,7 +387,7 @@ def apply(self, sdfg): for data_name, ranges in local_ranges.items(): for param, r in zip(variable_mapping[data_name], ranges): # create new range from this subset and assign - rng = subsets.Range((r, )) + rng = subsets.Range((r,)) if param: inferred_ranges[map_entry][param] = subsets.union(inferred_ranges[map_entry][param], rng) @@ -457,9 +450,9 @@ def apply(self, sdfg): reference_range_current = self.reference_range[param] min_diff = symbolic.SymExpr(reference_range_current.min_element()[0] \ - - target_range_current.min_element()[0]) + - target_range_current.min_element()[0]) max_diff = symbolic.SymExpr(target_range_current.max_element()[0] \ - - reference_range_current.max_element()[0]) + - reference_range_current.max_element()[0]) try: min_diff = symbolic.evaluate(min_diff, {})