Skip to content

Commit

Permalink
Forgot to update the references to subgraph fusion, fixing now
Browse files Browse the repository at this point in the history
+ Removing unnecessary imports.
  • Loading branch information
pratyai committed Oct 17, 2024
1 parent f355a6b commit 490b415
Showing 1 changed file with 23 additions and 30 deletions.
53 changes: 23 additions & 30 deletions dace/transformation/subgraph/stencil_tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,27 @@
""" This module contains classes and functions that implement the orthogonal
stencil tiling transformation. """

import math
import itertools
import warnings
from collections import defaultdict
from copy import deepcopy as dcpy

import dace
from dace import dtypes, registry, symbolic
import dace.subsets as subsets
import dace.symbolic as symbolic
from dace import dtypes
from dace.properties import make_properties, Property, ShapeProperty
from dace.sdfg import nodes
from dace.transformation import transformation
from dace.sdfg.propagation import _propagate_node

from dace.transformation.dataflow.map_for_loop import MapToForLoop
from dace.transformation.dataflow.map_expansion import MapExpansion
from dace.transformation import transformation
from dace.transformation.dataflow.map_collapse import MapCollapse
from dace.transformation.dataflow.map_expansion import MapExpansion
from dace.transformation.dataflow.map_for_loop import MapToForLoop
from dace.transformation.dataflow.strip_mining import StripMining
from dace.transformation.interstate.loop_unroll import LoopUnroll
from dace.transformation.interstate.loop_detection import DetectLoop
from dace.transformation.subgraph import SubgraphFusion

from copy import deepcopy as dcpy

import dace.subsets as subsets
import dace.symbolic as symbolic

import itertools
import warnings

from collections import defaultdict

from dace.transformation.interstate.loop_unroll import LoopUnroll
from dace.transformation.subgraph import helpers
from dace.transformation.subgraph import subgraph_fusion


@make_properties
Expand All @@ -51,7 +44,7 @@ class StencilTiling(transformation.SubgraphTransformation):

prefix = Property(dtype=str, default="stencil", desc="Prefix for new inner tiled range symbols")

strides = ShapeProperty(dtype=tuple, default=(1, ), desc="Tile stride")
strides = ShapeProperty(dtype=tuple, default=(1,), desc="Tile stride")

schedule = Property(dtype=dace.dtypes.ScheduleType,
default=dace.dtypes.ScheduleType.Default,
Expand Down Expand Up @@ -200,13 +193,13 @@ def can_be_applied(sdfg, subgraph) -> bool:

# get intermediate_nodes, out_nodes from SubgraphFusion Transformation
try:
node_config = get_adjacent_nodes(sdfg, graph, map_entries)
node_config = subgraph_fusion.get_adjacent_nodes(graph, map_entries)
(_, intermediate_nodes, out_nodes) = node_config
except NotImplementedError:
return False

# 1.4: check topological feasibility
if not check_topo_feasibility(sdfg, graph, map_entries, intermediate_nodes, out_nodes):
if not subgraph_fusion.check_topo_feasibility(graph, map_entries, intermediate_nodes, out_nodes):
return False
# 1.5 nodes that are both intermediate and out nodes
# are not supported in StencilTiling
Expand All @@ -215,8 +208,8 @@ def can_be_applied(sdfg, subgraph) -> bool:

# 1.6 check that we only deal with compressible transients

subgraph_contains_data = determine_compressible_nodes(sdfg, graph, intermediate_nodes,
map_entries, map_exits)
subgraph_contains_data = subgraph_fusion.determine_compressible_nodes(sdfg, graph, intermediate_nodes,
map_entries, map_exits)
if any([s == False for s in subgraph_contains_data.values()]):
return False

Expand Down Expand Up @@ -264,8 +257,8 @@ def can_be_applied(sdfg, subgraph) -> bool:
for i, (p_subset, c_subset) in enumerate(zip(parent_coverage, children_coverage)):

# transform into subset
p_subset = subsets.Range((p_subset, ))
c_subset = subsets.Range((c_subset, ))
p_subset = subsets.Range((p_subset,))
c_subset = subsets.Range((c_subset,))

# get associated parameter in memlet
params1 = symbolic.symlist(memlets[map_entry][1][data_name][i]).keys()
Expand All @@ -292,7 +285,7 @@ def can_be_applied(sdfg, subgraph) -> bool:
except KeyError:
return False

#parameter mapping must be the same
# parameter mapping must be the same
if param_parent_coverage != param_children_coverage:
return False

Expand Down Expand Up @@ -394,7 +387,7 @@ def apply(self, sdfg):
for data_name, ranges in local_ranges.items():
for param, r in zip(variable_mapping[data_name], ranges):
# create new range from this subset and assign
rng = subsets.Range((r, ))
rng = subsets.Range((r,))
if param:
inferred_ranges[map_entry][param] = subsets.union(inferred_ranges[map_entry][param], rng)

Expand Down Expand Up @@ -457,9 +450,9 @@ def apply(self, sdfg):
reference_range_current = self.reference_range[param]

min_diff = symbolic.SymExpr(reference_range_current.min_element()[0] \
- target_range_current.min_element()[0])
- target_range_current.min_element()[0])
max_diff = symbolic.SymExpr(target_range_current.max_element()[0] \
- reference_range_current.max_element()[0])
- reference_range_current.max_element()[0])

try:
min_diff = symbolic.evaluate(min_diff, {})
Expand Down

0 comments on commit 490b415

Please sign in to comment.