From 85aa33b2fa4ed89d965f03c2af622add903b5ba0 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 8 Nov 2022 14:37:09 +0100 Subject: [PATCH 01/98] Type hint --- dace/sdfg/state.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 7dafd41a3e..4ef09012fe 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -7,7 +7,7 @@ import inspect import itertools import warnings -from typing import Any, AnyStr, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, overload +from typing import TYPE_CHECKING, Any, AnyStr, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, overload import dace from dace import data as dt @@ -24,6 +24,8 @@ from dace.sdfg.validation import validate_state from dace.subsets import Range, Subset +if TYPE_CHECKING: + import dace.sdfg.scope def _getdebuginfo(old_dinfo=None) -> dtypes.DebugInfo: """ Returns a DebugInfo object for the position that called this function. From e2c227ede645a99bbd0406af46d7cf927df480b6 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 8 Nov 2022 14:39:58 +0100 Subject: [PATCH 02/98] (re)scheduling-oriented view of an SDFG as a tree of control/data flow constructs --- dace/sdfg/analysis/schedule_tree/__init__.py | 0 .../analysis/schedule_tree/sdfg_to_tree.py | 272 ++++++++++++++ dace/sdfg/analysis/schedule_tree/treenodes.py | 340 ++++++++++++++++++ 3 files changed, 612 insertions(+) create mode 100644 dace/sdfg/analysis/schedule_tree/__init__.py create mode 100644 dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py create mode 100644 dace/sdfg/analysis/schedule_tree/treenodes.py diff --git a/dace/sdfg/analysis/schedule_tree/__init__.py b/dace/sdfg/analysis/schedule_tree/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py new file mode 100644 index 0000000000..4db015b59c --- /dev/null +++ 
b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -0,0 +1,272 @@ +# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +import copy +from typing import Dict, List, Set +import dace +from dace.codegen import control_flow as cf +from dace.sdfg.sdfg import SDFG +from dace.sdfg.state import SDFGState +from dace.sdfg import utils as sdutil +import time +from dace.frontend.python.astutils import negate_expr +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.properties import CodeBlock + + +def state_schedule_tree(state: SDFGState, array_mapping: Dict[str, dace.Memlet]) -> List[tn.ScheduleTreeNode]: + """ + Use scope_tree to get nodes by scope. Traverse all scopes and return a string for each scope. + :return: A string for the whole state + """ + result: List[tn.ScheduleTreeNode] = [] + NODE_TO_SCOPE_TYPE = { + dace.nodes.MapEntry: tn.MapScope, + dace.nodes.ConsumeEntry: tn.ConsumeScope, + dace.nodes.PipelineEntry: tn.PipelineScope, + } + sdfg = state.parent + + scopes: List[List[tn.ScheduleTreeNode]] = [] + for node in sdutil.scope_aware_topological_sort(state): + if isinstance(node, dace.nodes.EntryNode): + scopes.append(result) + subnodes = [] + result.append(NODE_TO_SCOPE_TYPE[type(node)](node=node, children=subnodes)) + result = subnodes + elif isinstance(node, dace.nodes.ExitNode): + result = scopes.pop() + elif isinstance(node, dace.nodes.NestedSDFG): + nested_array_mapping = {} + for e in state.all_edges(node): + conn = e.dst_conn if e.dst is node else e.src_conn + if e.data.is_empty() or not conn: + continue + res = sdutil.map_view_to_array(node.sdfg.arrays[conn], sdfg.arrays[e.data.data], e.data.subset) + no_mapping = False + if res is None: + no_mapping = True + else: + mapping, expanded, squeezed = res + if expanded: # "newaxis" slices will be seen as views (for now) + no_mapping = True + else: + dname = e.data.data + if dname in array_mapping: # Trace through recursive nested SDFGs + dname = 
array_mapping[dname].data # TODO slice tracing + + # TODO: Add actual slice + new_memlet = copy.deepcopy(e.data) + new_memlet.data = dname + nested_array_mapping[conn] = new_memlet + + if no_mapping: # Must use view (nview = nested SDFG view) + result.append( + tn.NView(target=conn, + source=e.data.data, + memlet=e.data, + src_desc=sdfg.arrays[e.data.data], + view_desc=node.sdfg.arrays[conn])) + + # Insert the nested SDFG flattened + nested_stree = as_schedule_tree(node.sdfg, nested_array_mapping) + result.extend(nested_stree.children) + elif isinstance(node, dace.nodes.Tasklet): + in_memlets = {e.dst_conn: e.data for e in state.in_edges(node) if e.dst_conn} + out_memlets = {e.src_conn: e.data for e in state.out_edges(node) if e.src_conn} + result.append(tn.TaskletNode(node=node, in_memlets=in_memlets, out_memlets=out_memlets)) + elif isinstance(node, dace.nodes.LibraryNode): + in_memlets = {e.dst_conn: e.data for e in state.in_edges(node) if e.dst_conn} + out_memlets = {e.src_conn: e.data for e in state.out_edges(node) if e.src_conn} + result.append(tn.LibraryCall(node=node, in_memlets=in_memlets, out_memlets=out_memlets)) + elif isinstance(node, dace.nodes.AccessNode): + # Check type + desc = node.desc(sdfg) + vedge = None + if isinstance(desc, dace.data.View): + vedge = sdutil.get_view_edge(state, node) + + # Access nodes are only printed with corresponding memlets + for e in state.all_edges(node): + if e.data.is_empty(): + continue + conn = e.dst_conn if e.dst is node else e.src_conn + + # Reference + "set" connector + if conn == 'set': + result.append( + tn.RefSetNode(target=node.data, + memlet=e.data, + src_desc=sdfg.arrays[e.data.data], + ref_desc=sdfg.arrays[node.data])) + continue + # View edge + if e is vedge: + subset = e.data.get_src_subset(e, state) if e.dst is node else e.data.get_dst_subset(e, state) + vnode = sdutil.get_view_node(state, node) + new_memlet = copy.deepcopy(e.data) + new_memlet.data = node.data + new_memlet.subset = subset + 
new_memlet.other_subset = None + result.append( + tn.ViewNode(target=vnode.data, + source=node.data, + memlet=new_memlet, + src_desc=sdfg.arrays[vnode.data], + view_desc=sdfg.arrays[node.data])) + continue + + # Check if an incoming or outgoing memlet is a leaf (since the copy will be done at + # the innermost level) and leads to access node (otherwise taken care of in another node) + mpath = state.memlet_path(e) + if len(mpath) == 1 and e.dst is node: + # Special case: only annotate source in a simple copy + continue + if e.dst is node and mpath[-1] is e: + other = mpath[0].src + if not isinstance(other, dace.nodes.AccessNode): + continue + result.append(tn.CopyNode(target=node.data, memlet=e.data)) + continue + if e.src is node and mpath[0] is e: + other = mpath[-1].dst + if not isinstance(other, dace.nodes.AccessNode): + continue + result.append(tn.CopyNode(target=other.data, memlet=e.data)) + continue + + assert len(scopes) == 0 + + return result + + +def as_schedule_tree(sdfg: SDFG, array_mapping: Dict[str, dace.Memlet] = None) -> tn.ScheduleTreeScope: + """ + Converts an SDFG into a schedule tree. The schedule tree is a tree of nodes that represent the execution order of the SDFG. + Each node in the tree can either represent a single statement (symbol assignment, tasklet, copy, library node, etc.) or + a ``ScheduleTreeScope`` block (map, for-loop, pipeline, etc.) that contains other nodes. + + It can be used to generate code from an SDFG, or to perform schedule transformations on the SDFG. For example, + erasing an empty if branch, or merging two consecutive for-loops. The SDFG can then be reconstructed via the + ``from_schedule_tree`` function. + + :param sdfg: The SDFG to convert. + :param array_mapping: (Internal, should be left empty) A mapping from array names to memlets. + :return: A schedule tree representing the given SDFG. 
+ """ + + from dace.transformation import helpers as xfh # Avoid import loop + array_mapping = array_mapping or {} + + # Split edges with assignments and conditions + xfh.split_interstate_edges(sdfg) + + # Create initial tree from CFG + cfg: cf.ControlFlow = cf.structured_control_flow_tree(sdfg, lambda _: '') + + # Traverse said tree (also into states) to create the schedule tree + def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.ScheduleTreeNode]: + result: List[tn.ScheduleTreeNode] = [] + if isinstance(node, cf.GeneralBlock): + subnodes: List[tn.ScheduleTreeNode] = [] + for n in node.elements: + subnodes.extend(totree(n, node)) + if not node.sequential: + # Nest in general block + result = [tn.GBlock(children=subnodes)] + else: + # Use the sub-nodes directly + result = subnodes + + elif isinstance(node, cf.SingleState): + result = state_schedule_tree(node.state, array_mapping) + + # Add interstate assignments unrelated to structured control flow + if parent is not None: + for e in sdfg.out_edges(node.state): + edge_body = [] + + if e not in parent.assignments_to_ignore: + for aname, aval in e.data.assignments.items(): + edge_body.append(tn.AssignNode(name=aname, value=CodeBlock(aval))) + + if not parent.sequential: + if e not in parent.gotos_to_ignore: + edge_body.append(tn.GotoNode(target=e.dst.label)) + else: + if e in parent.gotos_to_break: + edge_body.append(tn.BreakNode()) + elif e in parent.gotos_to_continue: + edge_body.append(tn.ContinueNode()) + + if e not in parent.gotos_to_ignore and not e.data.is_unconditional(): + if sdfg.out_degree(node.state) == 1 and parent.sequential: + # Conditional state in sequential block! 
Add "if not condition goto exit" + result.append( + tn.StateIfScope(condition=CodeBlock(negate_expr(e.data.condition)), + children=tn.GotoNode(target=None))) + result.extend(edge_body) + else: + # Add "if condition" with the body above + result.append(tn.StateIfScope(condition=e.data.condition, children=edge_body)) + else: + result.extend(edge_body) + + elif isinstance(node, cf.ForScope): + result.append(tn.ForScope(header=node, children=totree(node.body))) + elif isinstance(node, cf.IfScope): + result.append(tn.IfScope(condition=node.condition, children=totree(node.body))) + elif isinstance(node, cf.IfElseChain): + # Add "if" for the first condition, "elif"s for the rest + result.append(tn.IfScope(condition=node.body[0][0], children=totree(node.body[0][1]))) + for cond, body in node.body[1:]: + result.append(tn.ElifScope(condition=cond, children=totree(body))) + # "else goto exit" + result.append(tn.ElseScope(children=[tn.GotoNode(target=None)])) + elif isinstance(node, cf.WhileScope): + result.append(tn.WhileScope(header=node, children=totree(node.body))) + elif isinstance(node, cf.DoWhileScope): + result.append(tn.DoWhileScope(header=node, children=totree(node.body))) + else: + # e.g., "SwitchCaseScope" + raise tn.UnsupportedScopeException(type(node).__name__) + + if node.first_state is not None: + result = [tn.StateLabel(state=node.first_state)] + result + + return result + + # Recursive traversal of the control flow tree + result = tn.ScheduleTreeScope(children=totree(cfg)) + + # Clean up tree + remove_unused_labels(result) + + return result + + +def remove_unused_labels(stree: tn.ScheduleTreeScope): + class FindGotos(tn.ScheduleNodeVisitor): + def __init__(self): + self.gotos: Set[str] = set() + + def visit_GotoNode(self, node: tn.GotoNode): + if node.target is not None: + self.gotos.add(node.target) + + class RemoveLabels(tn.ScheduleNodeTransformer): + def __init__(self, labels_to_keep: Set[str]) -> None: + self.labels_to_keep = labels_to_keep + + def 
visit_StateLabel(self, node: tn.StateLabel): + if node.state.name not in self.labels_to_keep: + return None + return node + + fg = FindGotos() + fg.visit(stree) + return RemoveLabels(fg.gotos).visit(stree) + + +if __name__ == '__main__': + stree = as_schedule_tree(sdfg) + with open('output_stree.txt', 'w') as fp: + fp.write(stree.as_string(-1)) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py new file mode 100644 index 0000000000..9aaf13308f --- /dev/null +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -0,0 +1,340 @@ +# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +from dataclasses import dataclass +from dace import nodes, data, subsets +from dace.codegen import control_flow as cf +from dace.properties import CodeBlock +from dace.sdfg import SDFG +from dace.sdfg.state import SDFGState +from dace.memlet import Memlet +from typing import Dict, List, Optional + +INDENTATION = ' ' + + +class UnsupportedScopeException(Exception): + pass + + +@dataclass +class ScheduleTreeNode: + def as_string(self, indent: int = 0): + return indent * INDENTATION + 'UNSUPPORTED' + + +@dataclass +class ScheduleTreeScope(ScheduleTreeNode): + children: List['ScheduleTreeNode'] + + def __init__(self, children: Optional[List['ScheduleTreeNode']] = None): + self.children = children or [] + + def as_string(self, indent: int = 0): + return '\n'.join([child.as_string(indent + 1) for child in self.children]) + + # TODO: Get input/output memlets? + + +@dataclass +class ControlFlowScope(ScheduleTreeScope): + pass + + +@dataclass +class DataflowScope(ScheduleTreeScope): + node: nodes.EntryNode + + +@dataclass +class GBlock(ControlFlowScope): + """ + General control flow block. Contains a list of states + that can run in arbitrary order based on edges (gotos). + Normally contains irreducible control flow. 
+ """ + def as_string(self, indent: int = 0): + result = indent * INDENTATION + 'gblock:\n' + return result + super().as_string(indent) + + pass + + +@dataclass +class StateLabel(ScheduleTreeNode): + state: SDFGState + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'label {self.state.name}:' + + +@dataclass +class GotoNode(ScheduleTreeNode): + target: Optional[str] = None #: If None, equivalent to "goto exit" or "return" + + def as_string(self, indent: int = 0): + name = self.target or 'exit' + return indent * INDENTATION + f'goto {name}' + + +@dataclass +class AssignNode(ScheduleTreeNode): + """ + Represents a symbol assignment that is not part of a structured control flow block. + """ + name: str + value: CodeBlock + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'assign {self.name} = {self.value.as_string}' + + +@dataclass +class ForScope(ControlFlowScope): + """ + For loop scope. + """ + header: cf.ForScope + + def as_string(self, indent: int = 0): + node = self.header + + result = (indent * INDENTATION + f'for {node.itervar} = {node.init}; {node.condition.as_string}; ' + f'{node.itervar} = {node.update}:\n') + return result + super().as_string(indent) + + +@dataclass +class WhileScope(ControlFlowScope): + """ + While loop scope. + """ + header: cf.WhileScope + + def as_string(self, indent: int = 0): + result = indent * INDENTATION + f'while {self.header.test.as_string}:\n' + return result + super().as_string(indent) + + +@dataclass +class DoWhileScope(ControlFlowScope): + """ + Do/While loop scope. + """ + header: cf.DoWhileScope + + def as_string(self, indent: int = 0): + header = indent * INDENTATION + 'do:\n' + footer = indent * INDENTATION + f'while {self.header.test.as_string}\n' + return header + super().as_string(indent) + footer + + +@dataclass +class IfScope(ControlFlowScope): + """ + If branch scope. 
+ """ + condition: CodeBlock + + def as_string(self, indent: int = 0): + result = indent * INDENTATION + f'if {self.condition.as_string}:\n' + return result + super().as_string(indent) + + +@dataclass +class StateIfScope(IfScope): + """ + A special class of an if scope in general blocks for if statements that are part of a state transition. + """ + def as_string(self, indent: int = 0): + result = indent * INDENTATION + f'stateif {self.condition.as_string}:\n' + return result + super().as_string(indent) + + +@dataclass +class BreakNode(ScheduleTreeNode): + """ + Represents a break statement. + """ + def as_string(self, indent: int = 0): + return indent * INDENTATION + 'break' + + +@dataclass +class ContinueNode(ScheduleTreeNode): + """ + Represents a continue statement. + """ + def as_string(self, indent: int = 0): + return indent * INDENTATION + 'continue' + + +@dataclass +class ElifScope(ControlFlowScope): + """ + Else-if branch scope. + """ + condition: CodeBlock + + def as_string(self, indent: int = 0): + result = indent * INDENTATION + f'elif {self.condition.as_string}:\n' + return result + super().as_string(indent) + + +@dataclass +class ElseScope(ControlFlowScope): + """ + Else branch scope. + """ + def as_string(self, indent: int = 0): + result = indent * INDENTATION + 'else:\n' + return result + super().as_string(indent) + + +@dataclass +class MapScope(DataflowScope): + """ + Map scope. + """ + def as_string(self, indent: int = 0): + rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) + result = indent * INDENTATION + f'map {", ".join(self.node.map.params)} in [{rangestr}]:\n' + return result + super().as_string(indent) + + +@dataclass +class ConsumeScope(DataflowScope): + """ + Consume scope. 
+ """ + def as_string(self, indent: int = 0): + node: nodes.ConsumeEntry = self.node + cond = 'stream not empty' if node.consume.condition is None else node.consume.condition.as_string + result = indent * INDENTATION + f'consume (PE {node.consume.pe_index} out of {node.consume.num_pes}) while {cond}:\n' + return result + super().as_string(indent) + + +@dataclass +class PipelineScope(DataflowScope): + """ + Pipeline scope. + """ + def as_string(self, indent: int = 0): + rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) + result = indent * INDENTATION + f'pipeline {", ".join(self.node.map.params)} in [{rangestr}]:\n' + return result + super().as_string(indent) + + +@dataclass +class TaskletNode(ScheduleTreeNode): + node: nodes.Tasklet + in_memlets: Dict[str, Memlet] + out_memlets: Dict[str, Memlet] + + def as_string(self, indent: int = 0): + in_memlets = ', '.join(f'{v}' for v in self.in_memlets.values()) + out_memlets = ', '.join(f'{v}' for v in self.out_memlets.values()) + return indent * INDENTATION + f'{out_memlets} = tasklet({in_memlets})' + + +@dataclass +class LibraryCall(ScheduleTreeNode): + node: nodes.LibraryNode + in_memlets: Dict[str, Memlet] + out_memlets: Dict[str, Memlet] + + def as_string(self, indent: int = 0): + in_memlets = ', '.join(f'{v}' for v in self.in_memlets.values()) + out_memlets = ', '.join(f'{v}' for v in self.out_memlets.values()) + libname = type(self.node).__name__ + # Get the properties of the library node without its superclasses + own_properties = ', '.join(f'{k}={getattr(self.node, k)}' for k, v in self.node.__properties__.items() + if v.owner not in {nodes.Node, nodes.CodeNode, nodes.LibraryNode}) + return indent * INDENTATION + f'{out_memlets} = library {libname}[{own_properties}]({in_memlets})' + + +@dataclass +class CopyNode(ScheduleTreeNode): + target: str + memlet: Memlet + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'{self.target} = copy {self.memlet}' + + 
+@dataclass +class ViewNode(ScheduleTreeNode): + target: str #: View name + source: str #: Viewed container name + memlet: Memlet + src_desc: data.Data + view_desc: data.Data + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'{self.target} = view {self.source}[{self.memlet}] as {self.view_desc}' + + +@dataclass +class NView(ViewNode): + """ + Nested SDFG view node. Subclass of a view that specializes in nested SDFG boundaries. + """ + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'{self.target} = nview {self.source}[{self.memlet}] as {self.view_desc}' + + +@dataclass +class RefSetNode(ScheduleTreeNode): + """ + Reference set node. Sets a reference to a data container. + """ + target: str + memlet: Memlet + src_desc: data.Data + ref_desc: data.Data + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'{self.target} = refset to {self.memlet}' + + +# Classes based on Python's AST NodeVisitor/NodeTransformer for schedule tree nodes +class ScheduleNodeVisitor: + def visit(self, node: ScheduleTreeNode): + """Visit a node.""" + if isinstance(node, list): + return [self.visit(snode) for snode in node] + + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + return visitor(node) + + def generic_visit(self, node: ScheduleTreeNode): + if isinstance(node, ScheduleTreeScope): + for child in node.children: + self.visit(child) + + +class ScheduleNodeTransformer(ScheduleNodeVisitor): + def visit(self, node: ScheduleTreeNode): + if isinstance(node, list): + result = [] + for snode in node: + new_node = self.visit(snode) + if new_node is not None: + result.append(new_node) + return result + + return super().visit(node) + + def generic_visit(self, node: ScheduleTreeNode): + new_values = [] + if isinstance(node, ScheduleTreeScope): + for value in node.children: + if isinstance(value, ScheduleTreeNode): + value = self.visit(value) + if value is None: + continue + elif 
not isinstance(value, ScheduleTreeNode): + new_values.extend(value) + continue + new_values.append(value) + node.children[:] = new_values + return node From 3c0ad6b3d3efecd804786c1c10586b8c28ec73b6 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 8 Nov 2022 16:38:14 +0100 Subject: [PATCH 03/98] Fix name bug with view and duplicate edge adds to the tree --- .../analysis/schedule_tree/sdfg_to_tree.py | 39 +++++++++++++++++-- dace/sdfg/analysis/schedule_tree/treenodes.py | 4 +- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 4db015b59c..c3673ce086 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -24,6 +24,7 @@ def state_schedule_tree(state: SDFGState, array_mapping: Dict[str, dace.Memlet]) dace.nodes.PipelineEntry: tn.PipelineScope, } sdfg = state.parent + edges_to_skip = set() scopes: List[List[tn.ScheduleTreeNode]] = [] for node in sdutil.scope_aware_topological_sort(state): @@ -84,10 +85,14 @@ def state_schedule_tree(state: SDFGState, array_mapping: Dict[str, dace.Memlet]) if isinstance(desc, dace.data.View): vedge = sdutil.get_view_edge(state, node) - # Access nodes are only printed with corresponding memlets + # Access nodes are only generated with corresponding memlets for e in state.all_edges(node): if e.data.is_empty(): continue + if e in edges_to_skip: + continue + edges_to_skip.add(e) # Only process each edge once + conn = e.dst_conn if e.dst is node else e.src_conn # Reference + "set" connector @@ -98,17 +103,25 @@ def state_schedule_tree(state: SDFGState, array_mapping: Dict[str, dace.Memlet]) src_desc=sdfg.arrays[e.data.data], ref_desc=sdfg.arrays[node.data])) continue + if e.src is node: + last_edge = state.memlet_path(e)[-1] + if isinstance(last_edge.dst, dace.nodes.AccessNode) and last_edge.dst_conn == 'set': + # Skip this edge, it is handled by the Reference 
node + edges_to_skip.remove(e) + continue + + # View edge if e is vedge: subset = e.data.get_src_subset(e, state) if e.dst is node else e.data.get_dst_subset(e, state) vnode = sdutil.get_view_node(state, node) new_memlet = copy.deepcopy(e.data) - new_memlet.data = node.data + new_memlet.data = vnode.data new_memlet.subset = subset new_memlet.other_subset = None result.append( - tn.ViewNode(target=vnode.data, - source=node.data, + tn.ViewNode(target=node.data, + source=vnode.data, memlet=new_memlet, src_desc=sdfg.arrays[vnode.data], view_desc=sdfg.arrays[node.data])) @@ -124,12 +137,30 @@ def state_schedule_tree(state: SDFGState, array_mapping: Dict[str, dace.Memlet]) other = mpath[0].src if not isinstance(other, dace.nodes.AccessNode): continue + + # Check if the other node is a view, and skip the view edge (handled by the view node) + other_desc = other.desc(sdfg) + if isinstance(other_desc, dace.data.View): + other_vedge = sdutil.get_view_edge(state, other) + if other_vedge is e: + edges_to_skip.remove(e) + continue + result.append(tn.CopyNode(target=node.data, memlet=e.data)) continue if e.src is node and mpath[0] is e: other = mpath[-1].dst if not isinstance(other, dace.nodes.AccessNode): continue + + # Check if the other node is a view, and skip the view edge (handled by the view node) + other_desc = other.desc(sdfg) + if isinstance(other_desc, dace.data.View): + other_vedge = sdutil.get_view_edge(state, other) + if other_vedge is e: + edges_to_skip.remove(e) + continue + result.append(tn.CopyNode(target=other.data, memlet=e.data)) continue diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 9aaf13308f..9edb813930 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -269,7 +269,7 @@ class ViewNode(ScheduleTreeNode): view_desc: data.Data def as_string(self, indent: int = 0): - return indent * INDENTATION + f'{self.target} = view 
{self.source}[{self.memlet}] as {self.view_desc}' + return indent * INDENTATION + f'{self.target} = view {self.memlet} as {self.view_desc.shape}' @dataclass @@ -278,7 +278,7 @@ class NView(ViewNode): Nested SDFG view node. Subclass of a view that specializes in nested SDFG boundaries. """ def as_string(self, indent: int = 0): - return indent * INDENTATION + f'{self.target} = nview {self.source}[{self.memlet}] as {self.view_desc}' + return indent * INDENTATION + f'{self.target} = nview {self.memlet} as {self.view_desc.shape}' @dataclass From 329d3929c381277250a3072f1e6d094179576dd1 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 11 Nov 2022 13:30:29 +0100 Subject: [PATCH 04/98] Bugfixes: add else clause, children as list --- dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index c3673ce086..153148fe22 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -233,7 +233,7 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche # Conditional state in sequential block! 
Add "if not condition goto exit" result.append( tn.StateIfScope(condition=CodeBlock(negate_expr(e.data.condition)), - children=tn.GotoNode(target=None))) + children=[tn.GotoNode(target=None)])) result.extend(edge_body) else: # Add "if condition" with the body above @@ -245,6 +245,8 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche result.append(tn.ForScope(header=node, children=totree(node.body))) elif isinstance(node, cf.IfScope): result.append(tn.IfScope(condition=node.condition, children=totree(node.body))) + if node.orelse is not None: + result.append(tn.ElseScope(children=totree(node.orelse))) elif isinstance(node, cf.IfElseChain): # Add "if" for the first condition, "elif"s for the rest result.append(tn.IfScope(condition=node.body[0][0], children=totree(node.body[0][1]))) @@ -298,6 +300,11 @@ def visit_StateLabel(self, node: tn.StateLabel): if __name__ == '__main__': - stree = as_schedule_tree(sdfg) + s = time.time() + sdfg = SDFG.from_file(sys.argv[1]) + print('Loaded SDFG in', time.time() - s, 'seconds') + s = time.time() + stree = as_schedule_tree(sdfg, in_place=True) + print('Created schedule tree in', time.time() - s, 'seconds') with open('output_stree.txt', 'w') as fp: - fp.write(stree.as_string(-1)) + fp.write(stree.as_string(-1) + '\n') From cda81e47ccd15a340bb7c36a8e351a84da26da93 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 11 Nov 2022 13:31:15 +0100 Subject: [PATCH 05/98] Remove duplicate labels --- dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 153148fe22..d183ef1a8d 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -271,13 +271,15 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche result = 
tn.ScheduleTreeScope(children=totree(cfg)) # Clean up tree - remove_unused_labels(result) + remove_unused_and_duplicate_labels(result) return result -def remove_unused_labels(stree: tn.ScheduleTreeScope): +def remove_unused_and_duplicate_labels(stree: tn.ScheduleTreeScope): + class FindGotos(tn.ScheduleNodeVisitor): + def __init__(self): self.gotos: Set[str] = set() @@ -286,12 +288,17 @@ def visit_GotoNode(self, node: tn.GotoNode): self.gotos.add(node.target) class RemoveLabels(tn.ScheduleNodeTransformer): + def __init__(self, labels_to_keep: Set[str]) -> None: self.labels_to_keep = labels_to_keep + self.labels_seen = set() def visit_StateLabel(self, node: tn.StateLabel): if node.state.name not in self.labels_to_keep: return None + if node.state.name in self.labels_seen: + return None + self.labels_seen.add(node.state.name) return node fg = FindGotos() From 09fbf2d3c0eeb7b70d82f7e56631cdd6af463157 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 11 Nov 2022 13:32:40 +0100 Subject: [PATCH 06/98] Revamp edge generation in schedule tree conversion and handle name collisions --- .../analysis/schedule_tree/sdfg_to_tree.py | 338 ++++++++++++------ dace/sdfg/analysis/schedule_tree/treenodes.py | 11 +- dace/sdfg/sdfg.py | 23 +- dace/transformation/helpers.py | 17 + 4 files changed, 285 insertions(+), 104 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index d183ef1a8d..04f1512344 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -2,19 +2,215 @@ import copy from typing import Dict, List, Set import dace +from dace import symbolic, data from dace.codegen import control_flow as cf from dace.sdfg.sdfg import SDFG from dace.sdfg.state import SDFGState -from dace.sdfg import utils as sdutil -import time +from dace.sdfg import utils as sdutil, graph as gr from dace.frontend.python.astutils import negate_expr from 
dace.sdfg.analysis.schedule_tree import treenodes as tn from dace.properties import CodeBlock +from dace.memlet import Memlet + +import time +import sys + + +def normalize_memlet(sdfg: SDFG, state: SDFGState, original: gr.MultiConnectorEdge[Memlet], data: str) -> Memlet: + """ + Normalizes a memlet to a given data descriptor. + + :param sdfg: The SDFG. + :param state: The state. + :param original: The original memlet. + :param data: The data descriptor. + :return: A new memlet. + """ + edge = copy.deepcopy(original) + edge.data.try_initialize(sdfg, state, edge) + + if edge.data.data == data: + return edge.data + + memlet = edge.data + if memlet._is_data_src: + new_subset, new_osubset = memlet.get_dst_subset(edge, state), memlet.get_src_subset(edge, state) + else: + new_subset, new_osubset = memlet.get_src_subset(edge, state), memlet.get_dst_subset(edge, state) + + memlet.data = data + memlet.subset = new_subset + memlet.other_subset = new_osubset + return memlet + + +def replace_memlets(sdfg: SDFG, array_mapping: Dict[str, Memlet]): + """ + Replaces all uses of data containers in memlets and interstate edges in an SDFG. + :param sdfg: The SDFG. + :param array_mapping: A mapping from internal data descriptor names to external memlets. + """ + # TODO replace, normalize, and compose + pass + + +def remove_name_collisions(sdfg: SDFG): + """ + Removes name collisions in nested SDFGs by renaming states, data containers, and symbols. + + :param sdfg: The SDFG. 
+ """ + state_names_seen = set() + identifiers_seen = set() + + for nsdfg in sdfg.all_sdfgs_recursive(): + # Rename duplicate states + for state in nsdfg.nodes(): + if state.label in state_names_seen: + state.set_label(data.find_new_name(state.label, state_names_seen)) + state_names_seen.add(state.label) + + replacements: Dict[str, str] = {} + parent_node = nsdfg.parent_nsdfg_node + + # Rename duplicate data containers + for name, desc in nsdfg.arrays.items(): + # Will already be renamed during conversion + if parent_node is not None and not desc.transient: + continue + + if name in identifiers_seen: + new_name = data.find_new_name(name, identifiers_seen) + replacements[name] = new_name + name = new_name + identifiers_seen.add(name) + + # Rename duplicate symbols + for name in nsdfg.get_all_symbols(): + # Will already be renamed during conversion + if parent_node is not None and name in parent_node.symbol_mapping: + continue + + if name in identifiers_seen: + new_name = data.find_new_name(name, identifiers_seen) + replacements[name] = new_name + name = new_name + identifiers_seen.add(name) + + # Rename duplicate constants + for name in nsdfg.constants_prop.keys(): + if name in identifiers_seen: + new_name = data.find_new_name(name, identifiers_seen) + replacements[name] = new_name + name = new_name + identifiers_seen.add(name) + + # If there is a name collision, replace all uses of the old names with the new names + if replacements: + nsdfg.replace_dict(replacements) + + +def _make_view_node(state: SDFGState, edge: gr.MultiConnectorEdge[Memlet], view_name: str, + viewed_name: str) -> tn.ViewNode: + """ + Helper function to create a view schedule tree node from a memlet edge. 
+ """ + sdfg = state.parent + normalized = normalize_memlet(sdfg, state, edge, viewed_name) + return tn.ViewNode(target=view_name, + source=viewed_name, + memlet=normalized, + src_desc=sdfg.arrays[viewed_name], + view_desc=sdfg.arrays[view_name]) + + +def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[Memlet], tn.ScheduleTreeNode]: + """ + Creates a dictionary mapping edges to their corresponding schedule tree nodes, if relevant. + + :param state: The state. + """ + result: Dict[gr.MultiConnectorEdge[Memlet], tn.ScheduleTreeNode] = {} + edges_to_ignore = set() + sdfg = state.parent + + for edge in state.edges(): + if edge in edges_to_ignore or edge in result: + continue + if edge.data.is_empty(): # Ignore empty memlets + edges_to_ignore.add(edge) + continue + + # Part of a memlet path - only consider innermost memlets + mtree = state.memlet_tree(edge) + all_edges = set(e for e in mtree) + leaves = set(mtree.leaves()) + edges_to_ignore.update(all_edges - leaves) + + # For every tree leaf, create a copy/view/reference set node as necessary + for e in leaves: + if e in edges_to_ignore or e in result: + continue + + # 1. Check for views + if isinstance(e.src, dace.nodes.AccessNode): + desc = e.src.desc(sdfg) + if isinstance(desc, dace.data.View): + vedge = sdutil.get_view_edge(state, e.src) + if e is vedge: + viewed_node = sdutil.get_view_node(state, e.src) + result[e] = _make_view_node(state, e, e.src.data, viewed_node.data) + continue + if isinstance(e.dst, dace.nodes.AccessNode): + desc = e.dst.desc(sdfg) + if isinstance(desc, dace.data.View): + vedge = sdutil.get_view_edge(state, e.dst) + if e is vedge: + viewed_node = sdutil.get_view_node(state, e.dst) + result[e] = _make_view_node(state, e, e.dst.data, viewed_node.data) + continue + + # 2. 
Check for reference sets + if isinstance(e.dst, dace.nodes.AccessNode) and e.dst_conn == 'set': + assert isinstance(e.dst.desc(sdfg), dace.data.Reference) + result[e] = tn.RefSetNode(target=e.data.data, + memlet=e.data, + src_desc=sdfg.arrays[e.data.data], + ref_desc=sdfg.arrays[e.dst.data]) + continue + + # 3. Check for copies + # Get both ends of the memlet path + mpath = state.memlet_path(e) + src = mpath[0].src + dst = mpath[-1].dst + if not isinstance(src, dace.nodes.AccessNode): + continue + if not isinstance(dst, dace.nodes.AccessNode): + continue + + # If the edge destination is the innermost node, it is a downward-pointing path + is_target_dst = e.dst is dst + + innermost_node = dst if is_target_dst else src + outermost_node = src if is_target_dst else dst + + # Normalize memlets to their innermost node, or source->destination if it is a same-scope edge + if e.src is src and e.dst is dst: + outermost_node = src + innermost_node = dst + + new_memlet = normalize_memlet(sdfg, state, e, outermost_node.data) + result[e] = tn.CopyNode(target=innermost_node.data, memlet=new_memlet) + + return result -def state_schedule_tree(state: SDFGState, array_mapping: Dict[str, dace.Memlet]) -> List[tn.ScheduleTreeNode]: +def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: """ - Use scope_tree to get nodes by scope. Traverse all scopes and return a string for each scope. + Use scope-aware topological sort to get nodes by scope and return the schedule tree of this state. + + :param state: The state. 
:return: A string for the whole state """ result: List[tn.ScheduleTreeNode] = [] @@ -24,7 +220,9 @@ def state_schedule_tree(state: SDFGState, array_mapping: Dict[str, dace.Memlet]) dace.nodes.PipelineEntry: tn.PipelineScope, } sdfg = state.parent - edges_to_skip = set() + + edge_to_stree: Dict[gr.MultiConnectorEdge[Memlet], tn.ScheduleTreeNode] = prepare_schedule_tree_edges(state) + edges_to_ignore = set() scopes: List[List[tn.ScheduleTreeNode]] = [] for node in sdutil.scope_aware_topological_sort(state): @@ -37,6 +235,13 @@ def state_schedule_tree(state: SDFGState, array_mapping: Dict[str, dace.Memlet]) result = scopes.pop() elif isinstance(node, dace.nodes.NestedSDFG): nested_array_mapping = {} + + # Replace symbols and memlets in nested SDFGs to match the namespace of the parent SDFG + # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes + symbolic.safe_replace(node.symbol_mapping, node.sdfg.replace_dict) + replace_memlets(node.sdfg, nested_array_mapping) + + # Create memlets for nested SDFG mapping, or nview schedule nodes if slice cannot be determined for e in state.all_edges(node): conn = e.dst_conn if e.dst is node else e.src_conn if e.data.is_empty() or not conn: @@ -50,14 +255,7 @@ def state_schedule_tree(state: SDFGState, array_mapping: Dict[str, dace.Memlet]) if expanded: # "newaxis" slices will be seen as views (for now) no_mapping = True else: - dname = e.data.data - if dname in array_mapping: # Trace through recursive nested SDFGs - dname = array_mapping[dname].data # TODO slice tracing - - # TODO: Add actual slice - new_memlet = copy.deepcopy(e.data) - new_memlet.data = dname - nested_array_mapping[conn] = new_memlet + nested_array_mapping[conn] = e.data if no_mapping: # Must use view (nview = nested SDFG view) result.append( @@ -67,8 +265,9 @@ def state_schedule_tree(state: SDFGState, array_mapping: Dict[str, dace.Memlet]) src_desc=sdfg.arrays[e.data.data], view_desc=node.sdfg.arrays[conn])) + # Insert the nested SDFG flattened - 
nested_stree = as_schedule_tree(node.sdfg, nested_array_mapping) + nested_stree = as_schedule_tree(node.sdfg, in_place=True, toplevel=False) result.extend(nested_stree.children) elif isinstance(node, dace.nodes.Tasklet): in_memlets = {e.dst_conn: e.data for e in state.in_edges(node) if e.dst_conn} @@ -79,97 +278,20 @@ def state_schedule_tree(state: SDFGState, array_mapping: Dict[str, dace.Memlet]) out_memlets = {e.src_conn: e.data for e in state.out_edges(node) if e.src_conn} result.append(tn.LibraryCall(node=node, in_memlets=in_memlets, out_memlets=out_memlets)) elif isinstance(node, dace.nodes.AccessNode): - # Check type - desc = node.desc(sdfg) - vedge = None - if isinstance(desc, dace.data.View): - vedge = sdutil.get_view_edge(state, node) - - # Access nodes are only generated with corresponding memlets + # If one of the neighboring edges has a schedule tree node attached to it, use that for e in state.all_edges(node): - if e.data.is_empty(): - continue - if e in edges_to_skip: - continue - edges_to_skip.add(e) # Only process each edge once - - conn = e.dst_conn if e.dst is node else e.src_conn - - # Reference + "set" connector - if conn == 'set': - result.append( - tn.RefSetNode(target=node.data, - memlet=e.data, - src_desc=sdfg.arrays[e.data.data], - ref_desc=sdfg.arrays[node.data])) - continue - if e.src is node: - last_edge = state.memlet_path(e)[-1] - if isinstance(last_edge.dst, dace.nodes.AccessNode) and last_edge.dst_conn == 'set': - # Skip this edge, it is handled by the Reference node - edges_to_skip.remove(e) - continue - - - # View edge - if e is vedge: - subset = e.data.get_src_subset(e, state) if e.dst is node else e.data.get_dst_subset(e, state) - vnode = sdutil.get_view_node(state, node) - new_memlet = copy.deepcopy(e.data) - new_memlet.data = vnode.data - new_memlet.subset = subset - new_memlet.other_subset = None - result.append( - tn.ViewNode(target=node.data, - source=vnode.data, - memlet=new_memlet, - src_desc=sdfg.arrays[vnode.data], - 
view_desc=sdfg.arrays[node.data])) - continue - - # Check if an incoming or outgoing memlet is a leaf (since the copy will be done at - # the innermost level) and leads to access node (otherwise taken care of in another node) - mpath = state.memlet_path(e) - if len(mpath) == 1 and e.dst is node: - # Special case: only annotate source in a simple copy - continue - if e.dst is node and mpath[-1] is e: - other = mpath[0].src - if not isinstance(other, dace.nodes.AccessNode): - continue - - # Check if the other node is a view, and skip the view edge (handled by the view node) - other_desc = other.desc(sdfg) - if isinstance(other_desc, dace.data.View): - other_vedge = sdutil.get_view_edge(state, other) - if other_vedge is e: - edges_to_skip.remove(e) - continue - - result.append(tn.CopyNode(target=node.data, memlet=e.data)) - continue - if e.src is node and mpath[0] is e: - other = mpath[-1].dst - if not isinstance(other, dace.nodes.AccessNode): - continue - - # Check if the other node is a view, and skip the view edge (handled by the view node) - other_desc = other.desc(sdfg) - if isinstance(other_desc, dace.data.View): - other_vedge = sdutil.get_view_edge(state, other) - if other_vedge is e: - edges_to_skip.remove(e) - continue - - result.append(tn.CopyNode(target=other.data, memlet=e.data)) + if e in edges_to_ignore: continue + if e in edge_to_stree: + result.append(edge_to_stree[e]) + edges_to_ignore.add(e) assert len(scopes) == 0 return result -def as_schedule_tree(sdfg: SDFG, array_mapping: Dict[str, dace.Memlet] = None) -> tn.ScheduleTreeScope: +def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) -> tn.ScheduleTreeScope: """ Converts an SDFG into a schedule tree. The schedule tree is a tree of nodes that represent the execution order of the SDFG. Each node in the tree can either represent a single statement (symbol assignment, tasklet, copy, library node, etc.) 
or @@ -180,16 +302,30 @@ def as_schedule_tree(sdfg: SDFG, array_mapping: Dict[str, dace.Memlet] = None) - ``from_schedule_tree`` function. :param sdfg: The SDFG to convert. - :param array_mapping: (Internal, should be left empty) A mapping from array names to memlets. + :param in_place: If True, the SDFG is modified in-place. Otherwise, a copy is made. Note that the SDFG might not be + usable after the conversion if ``in_place`` is True! :return: A schedule tree representing the given SDFG. """ - from dace.transformation import helpers as xfh # Avoid import loop - array_mapping = array_mapping or {} + + if not in_place: + sdfg = copy.deepcopy(sdfg) + + # Prepare SDFG for conversion + ############################# # Split edges with assignments and conditions xfh.split_interstate_edges(sdfg) + # Replace code->code edges with data<->code edges + xfh.replace_code_to_code_edges(sdfg) + + if toplevel: # Top-level SDFG preparation (only perform once) + # Handle name collisions (in arrays, state labels, symbols) + remove_name_collisions(sdfg) + + ############################# + # Create initial tree from CFG cfg: cf.ControlFlow = cf.structured_control_flow_tree(sdfg, lambda _: '') @@ -208,7 +344,7 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche result = subnodes elif isinstance(node, cf.SingleState): - result = state_schedule_tree(node.state, array_mapping) + result = state_schedule_tree(node.state) # Add interstate assignments unrelated to structured control flow if parent is not None: diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 9edb813930..08e4d07331 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -257,7 +257,16 @@ class CopyNode(ScheduleTreeNode): memlet: Memlet def as_string(self, indent: int = 0): - return indent * INDENTATION + f'{self.target} = copy {self.memlet}' + if any(s != 0 for s in 
self.memlet.other_subset.min_element()): + offset = f'[{self.memlet.other_subset}]' + else: + offset = '' + if self.memlet.wcr is not None: + wcr = f' with {self.memlet.wcr}' + else: + wcr = '' + + return indent * INDENTATION + f'{self.target}{offset} = copy {self.memlet.data}[{self.memlet.subset}]{wcr}' @dataclass diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 31406798ea..011ea56922 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -212,7 +212,6 @@ def free_symbols(self) -> Set[str]: """ Returns a set of symbols used in this edge's properties. """ return self.read_symbols() - set(self.assignments.keys()) - def replace_dict(self, repl: Dict[str, str], replace_keys=True) -> None: """ Replaces all given keys with their corresponding values. @@ -1261,6 +1260,26 @@ def free_symbols(self) -> Set[str]: # Subtract symbols defined in inter-state edges and constants return free_syms - defined_syms + def get_all_symbols(self) -> Set[str]: + """ + Returns a set of all symbol names that are used by the SDFG. + """ + # Exclude constants and data descriptor names + exclude = set(self.arrays.keys()) | set(self.constants_prop.keys()) + + syms = set() + + # Start with the set of SDFG free symbols + syms |= set(self.symbols.keys()) + + # Add inter-state symbols + for e in self.edges(): + syms |= set(e.data.assignments.keys()) + syms |= e.data.free_symbols + + # Subtract exluded symbols + return syms - exclude + def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: """ Determines what data containers are read and written in this SDFG. Does @@ -2230,7 +2249,7 @@ def compile(self, output_file=None, validate=True) -> \ index += 1 if self.name != sdfg.name: warnings.warn('SDFG "%s" is already loaded by another object, ' - 'recompiling under a different name.' % self.name) + 'recompiling under a different name.' 
% self.name) try: # Fill in scope entry/exit connectors diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 1abbd0856c..621c7822f7 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -1286,3 +1286,20 @@ def redirect_edge(state: SDFGState, new_edge = state.add_edge(new_src or edge.src, new_src_conn or edge.src_conn, new_dst or edge.dst, new_dst_conn or edge.dst_conn, memlet) return new_edge + + +def replace_code_to_code_edges(sdfg: SDFG): + """ + Adds access nodes between all code->code edges in each state. + + :param sdfg: The SDFG to process. + """ + for state in sdfg.nodes(): + for edge in state.edges(): + if not isinstance(edge.src, nodes.CodeNode) or not isinstance(edge.dst, nodes.CodeNode): + continue + # Add access nodes + aname = state.add_access(edge.data.data) + state.add_edge(edge.src, edge.src_conn, aname, None, edge.data) + state.add_edge(aname, None, edge.dst, edge.dst_conn, copy.deepcopy(edge.data)) + state.remove_edge(edge) From d1774a970c24330d5eedc14fb4d34cc65f0c94bf Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 11 Nov 2022 13:51:27 +0100 Subject: [PATCH 07/98] Refactor schedule tree passes and add an example pass --- dace/sdfg/analysis/schedule_tree/passes.py | 63 +++++++++++++++++++ .../analysis/schedule_tree/sdfg_to_tree.py | 35 +---------- 2 files changed, 66 insertions(+), 32 deletions(-) create mode 100644 dace/sdfg/analysis/schedule_tree/passes.py diff --git a/dace/sdfg/analysis/schedule_tree/passes.py b/dace/sdfg/analysis/schedule_tree/passes.py new file mode 100644 index 0000000000..52a58adc32 --- /dev/null +++ b/dace/sdfg/analysis/schedule_tree/passes.py @@ -0,0 +1,63 @@ +# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +""" +Assortment of passes for schedule trees. 
+""" + +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from typing import Set + + +def remove_unused_and_duplicate_labels(stree: tn.ScheduleTreeScope): + """ + Removes unused and duplicate labels from the schedule tree. + + :param stree: The schedule tree to remove labels from. + """ + + class FindGotos(tn.ScheduleNodeVisitor): + + def __init__(self): + self.gotos: Set[str] = set() + + def visit_GotoNode(self, node: tn.GotoNode): + if node.target is not None: + self.gotos.add(node.target) + + class RemoveLabels(tn.ScheduleNodeTransformer): + + def __init__(self, labels_to_keep: Set[str]) -> None: + self.labels_to_keep = labels_to_keep + self.labels_seen = set() + + def visit_StateLabel(self, node: tn.StateLabel): + if node.state.name not in self.labels_to_keep: + return None + if node.state.name in self.labels_seen: + return None + self.labels_seen.add(node.state.name) + return node + + fg = FindGotos() + fg.visit(stree) + return RemoveLabels(fg.gotos).visit(stree) + + +def remove_empty_scopes(stree: tn.ScheduleTreeScope): + """ + Removes empty scopes from the schedule tree. + + :warning: This pass is not safe to use for for-loops, as it will remove indices that may be used after the loop. 
+ """ + + class RemoveEmptyScopes(tn.ScheduleNodeTransformer): + + def visit(self, node: tn.ScheduleTreeNode): + if not isinstance(node, tn.ScheduleTreeScope): + return super().visit(node) + + if len(node.children) == 0: + return None + + return self.generic_visit(node) + + return RemoveEmptyScopes().visit(stree) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 04f1512344..7fc5ddffe8 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -8,7 +8,7 @@ from dace.sdfg.state import SDFGState from dace.sdfg import utils as sdutil, graph as gr from dace.frontend.python.astutils import negate_expr -from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.sdfg.analysis.schedule_tree import treenodes as tn, passes as stpasses from dace.properties import CodeBlock from dace.memlet import Memlet @@ -407,41 +407,11 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche result = tn.ScheduleTreeScope(children=totree(cfg)) # Clean up tree - remove_unused_and_duplicate_labels(result) + stpasses.remove_unused_and_duplicate_labels(result) return result -def remove_unused_and_duplicate_labels(stree: tn.ScheduleTreeScope): - - class FindGotos(tn.ScheduleNodeVisitor): - - def __init__(self): - self.gotos: Set[str] = set() - - def visit_GotoNode(self, node: tn.GotoNode): - if node.target is not None: - self.gotos.add(node.target) - - class RemoveLabels(tn.ScheduleNodeTransformer): - - def __init__(self, labels_to_keep: Set[str]) -> None: - self.labels_to_keep = labels_to_keep - self.labels_seen = set() - - def visit_StateLabel(self, node: tn.StateLabel): - if node.state.name not in self.labels_to_keep: - return None - if node.state.name in self.labels_seen: - return None - self.labels_seen.add(node.state.name) - return node - - fg = FindGotos() - fg.visit(stree) - return RemoveLabels(fg.gotos).visit(stree) - - 
if __name__ == '__main__': s = time.time() sdfg = SDFG.from_file(sys.argv[1]) @@ -449,5 +419,6 @@ def visit_StateLabel(self, node: tn.StateLabel): s = time.time() stree = as_schedule_tree(sdfg, in_place=True) print('Created schedule tree in', time.time() - s, 'seconds') + with open('output_stree.txt', 'w') as fp: fp.write(stree.as_string(-1) + '\n') From 4518c340692231e6b4f74b3f3c42b3e555090e6a Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 11 Nov 2022 14:11:41 +0100 Subject: [PATCH 08/98] Support dynamic scope inputs --- .../analysis/schedule_tree/sdfg_to_tree.py | 22 ++++++++++++++++--- dace/sdfg/analysis/schedule_tree/treenodes.py | 4 ++-- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 7fc5ddffe8..64b06c0c48 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -186,7 +186,7 @@ def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[ dst = mpath[-1].dst if not isinstance(src, dace.nodes.AccessNode): continue - if not isinstance(dst, dace.nodes.AccessNode): + if not isinstance(dst, (dace.nodes.AccessNode, dace.nodes.EntryNode)): continue # If the edge destination is the innermost node, it is a downward-pointing path @@ -200,8 +200,15 @@ def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[ outermost_node = src innermost_node = dst - new_memlet = normalize_memlet(sdfg, state, e, outermost_node.data) - result[e] = tn.CopyNode(target=innermost_node.data, memlet=new_memlet) + if isinstance(dst, dace.nodes.EntryNode): + # Special case: dynamic map range has no data + target_name = e.dst_conn + new_memlet = e.data + else: + target_name = innermost_node.data + new_memlet = normalize_memlet(sdfg, state, e, outermost_node.data) + + result[e] = tn.CopyNode(target=target_name, memlet=new_memlet) return result @@ -227,6 +234,15 @@ 
def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: scopes: List[List[tn.ScheduleTreeNode]] = [] for node in sdutil.scope_aware_topological_sort(state): if isinstance(node, dace.nodes.EntryNode): + # Handle dynamic scope inputs + for e in state.in_edges(node): + if e in edges_to_ignore: + continue + if e in edge_to_stree: + result.append(edge_to_stree[e]) + edges_to_ignore.add(e) + + # Create scope node and add to stack scopes.append(result) subnodes = [] result.append(NODE_TO_SCOPE_TYPE[type(node)](node=node, children=subnodes)) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 08e4d07331..b58df2db05 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -257,7 +257,7 @@ class CopyNode(ScheduleTreeNode): memlet: Memlet def as_string(self, indent: int = 0): - if any(s != 0 for s in self.memlet.other_subset.min_element()): + if self.memlet.other_subset is not None and any(s != 0 for s in self.memlet.other_subset.min_element()): offset = f'[{self.memlet.other_subset}]' else: offset = '' @@ -265,7 +265,7 @@ def as_string(self, indent: int = 0): wcr = f' with {self.memlet.wcr}' else: wcr = '' - + return indent * INDENTATION + f'{self.target}{offset} = copy {self.memlet.data}[{self.memlet.subset}]{wcr}' From 7949ebdccd653ea1d84b870bb4806d7c2348f7ae Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 11 Nov 2022 16:47:33 +0100 Subject: [PATCH 09/98] Explicit dynamic scope input copy node --- .../analysis/schedule_tree/sdfg_to_tree.py | 7 ++---- dace/sdfg/analysis/schedule_tree/treenodes.py | 24 +++++++++++++++++++ 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 64b06c0c48..e4ef82014c 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -202,13 
+202,11 @@ def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[ if isinstance(dst, dace.nodes.EntryNode): # Special case: dynamic map range has no data - target_name = e.dst_conn - new_memlet = e.data + result[e] = tn.DynScopeCopyNode(target=e.dst_conn, memlet=e.data) else: target_name = innermost_node.data new_memlet = normalize_memlet(sdfg, state, e, outermost_node.data) - - result[e] = tn.CopyNode(target=target_name, memlet=new_memlet) + result[e] = tn.CopyNode(target=target_name, memlet=new_memlet) return result @@ -281,7 +279,6 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: src_desc=sdfg.arrays[e.data.data], view_desc=node.sdfg.arrays[conn])) - # Insert the nested SDFG flattened nested_stree = as_schedule_tree(node.sdfg, in_place=True, toplevel=False) result.extend(nested_stree.children) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index b58df2db05..2e17fa378a 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -17,6 +17,7 @@ class UnsupportedScopeException(Exception): @dataclass class ScheduleTreeNode: + def as_string(self, indent: int = 0): return indent * INDENTATION + 'UNSUPPORTED' @@ -51,6 +52,7 @@ class GBlock(ControlFlowScope): that can run in arbitrary order based on edges (gotos). Normally contains irreducible control flow. """ + def as_string(self, indent: int = 0): result = indent * INDENTATION + 'gblock:\n' return result + super().as_string(indent) @@ -144,6 +146,7 @@ class StateIfScope(IfScope): """ A special class of an if scope in general blocks for if statements that are part of a state transition. """ + def as_string(self, indent: int = 0): result = indent * INDENTATION + f'stateif {self.condition.as_string}:\n' return result + super().as_string(indent) @@ -154,6 +157,7 @@ class BreakNode(ScheduleTreeNode): """ Represents a break statement. 
""" + def as_string(self, indent: int = 0): return indent * INDENTATION + 'break' @@ -163,6 +167,7 @@ class ContinueNode(ScheduleTreeNode): """ Represents a continue statement. """ + def as_string(self, indent: int = 0): return indent * INDENTATION + 'continue' @@ -184,6 +189,7 @@ class ElseScope(ControlFlowScope): """ Else branch scope. """ + def as_string(self, indent: int = 0): result = indent * INDENTATION + 'else:\n' return result + super().as_string(indent) @@ -194,6 +200,7 @@ class MapScope(DataflowScope): """ Map scope. """ + def as_string(self, indent: int = 0): rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) result = indent * INDENTATION + f'map {", ".join(self.node.map.params)} in [{rangestr}]:\n' @@ -205,6 +212,7 @@ class ConsumeScope(DataflowScope): """ Consume scope. """ + def as_string(self, indent: int = 0): node: nodes.ConsumeEntry = self.node cond = 'stream not empty' if node.consume.condition is None else node.consume.condition.as_string @@ -217,6 +225,7 @@ class PipelineScope(DataflowScope): """ Pipeline scope. """ + def as_string(self, indent: int = 0): rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) result = indent * INDENTATION + f'pipeline {", ".join(self.node.map.params)} in [{rangestr}]:\n' @@ -269,6 +278,18 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target}{offset} = copy {self.memlet.data}[{self.memlet.subset}]{wcr}' +@dataclass +class DynScopeCopyNode(ScheduleTreeNode): + """ + A special case of a copy node that is used in dynamic scope inputs (e.g., dynamic map ranges). + """ + target: str + memlet: Memlet + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'{self.target} = dscopy {self.memlet.data}[{self.memlet.subset}]' + + @dataclass class ViewNode(ScheduleTreeNode): target: str #: View name @@ -286,6 +307,7 @@ class NView(ViewNode): """ Nested SDFG view node. 
Subclass of a view that specializes in nested SDFG boundaries. """ + def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = nview {self.memlet} as {self.view_desc.shape}' @@ -306,6 +328,7 @@ def as_string(self, indent: int = 0): # Classes based on Python's AST NodeVisitor/NodeTransformer for schedule tree nodes class ScheduleNodeVisitor: + def visit(self, node: ScheduleTreeNode): """Visit a node.""" if isinstance(node, list): @@ -322,6 +345,7 @@ def generic_visit(self, node: ScheduleTreeNode): class ScheduleNodeTransformer(ScheduleNodeVisitor): + def visit(self, node: ScheduleTreeNode): if isinstance(node, list): result = [] From adbf8fc1f2564dba95baf1f734d25748f51e40de Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 9 Dec 2022 14:25:06 +0100 Subject: [PATCH 10/98] Add memlet parsing and replacement functionality for inter-state edge de-aliasing --- dace/frontend/python/memlet_parser.py | 4 +- dace/sdfg/memlet_utils.py | 83 +++++++++++++++++++++++++++ tests/sdfg/memlet_utils_test.py | 67 +++++++++++++++++++++ 3 files changed, 152 insertions(+), 2 deletions(-) create mode 100644 dace/sdfg/memlet_utils.py create mode 100644 tests/sdfg/memlet_utils_test.py diff --git a/dace/frontend/python/memlet_parser.py b/dace/frontend/python/memlet_parser.py index 6ef627a430..05ca8a5b82 100644 --- a/dace/frontend/python/memlet_parser.py +++ b/dace/frontend/python/memlet_parser.py @@ -200,7 +200,7 @@ def _fill_missing_slices(das, ast_ndslice, array, indices): def parse_memlet_subset(array: data.Data, node: Union[ast.Name, ast.Subscript], das: Dict[str, Any], - parsed_slice: Any = None) -> Tuple[subsets.Range, List[int]]: + parsed_slice: Any = None) -> Tuple[subsets.Range, List[int], List[int]]: """ Parses an AST subset and returns access range, as well as new dimensions to add. @@ -209,7 +209,7 @@ def parse_memlet_subset(array: data.Data, e.g., negative indices or empty shapes). :param node: AST node representing whole array or subset thereof. 
:param das: Dictionary of defined arrays and symbols mapped to their values. - :return: A 2-tuple of (subset, list of new axis indices). + :return: A 3-tuple of (subset, list of new axis indices, list of index-to-array-dimension correspondence). """ # Get memlet range ndslice = [(0, s - 1, 1) for s in array.shape] diff --git a/dace/sdfg/memlet_utils.py b/dace/sdfg/memlet_utils.py new file mode 100644 index 0000000000..c0c5201e50 --- /dev/null +++ b/dace/sdfg/memlet_utils.py @@ -0,0 +1,83 @@ +# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. + +import ast +from dace.frontend.python import astutils, memlet_parser +from dace.sdfg import SDFG, SDFGState, nodes +from dace.sdfg import graph as gr +from dace.sdfg import utils as sdutil +from dace.properties import CodeBlock +from dace import data, subsets, Memlet +from typing import Callable, Dict, Optional, Set, Union + + +class MemletReplacer(ast.NodeTransformer): + """ + Iterates over all memlet expressions (name or subscript with matching array in SDFG) in a code block. + The callable can also return another memlet to replace the current one. + """ + + def __init__(self, + arrays: Dict[str, data.Data], + process: Callable[[Memlet], Union[Memlet, None]], + array_filter: Optional[Set[str]] = None) -> None: + """ + Create a new memlet replacer. + + :param arrays: A mapping from array names to data descriptors. + :param process: A callable that takes a memlet and returns a memlet or None. + :param array_filter: An optional subset of array names to process. + """ + self.process = process + self.arrays = arrays + self.array_filter = array_filter or self.arrays.keys() + + def _parse_memlet(self, node: Union[ast.Name, ast.Subscript]) -> Memlet: + """ + Parses a memlet from a subscript or name node. + + :param node: The node to parse. + :return: The parsed memlet. 
+ """ + # Get array name + if isinstance(node, ast.Name): + data = node.id + elif isinstance(node, ast.Subscript): + data = node.value.id + else: + raise TypeError('Expected Name or Subscript') + + # Parse memlet subset + array = self.arrays[data] + subset, newaxes, _ = memlet_parser.parse_memlet_subset(array, node, self.arrays) + if newaxes: + raise NotImplementedError('Adding new axes to memlets is not supported') + + return Memlet(data=data, subset=subset) + + def _memlet_to_ast(self, memlet: Memlet) -> ast.Subscript: + """ + Converts a memlet to a subscript node. + + :param memlet: The memlet to convert. + :return: The converted node. + """ + return ast.parse(f'{memlet.data}[{memlet.subset}]').body[0].value + + def _replace(self, node: Union[ast.Name, ast.Subscript]) -> ast.Subscript: + cur_memlet = self._parse_memlet(node) + new_memlet = self.process(cur_memlet) + if new_memlet is None: + return node + + new_node = self._memlet_to_ast(new_memlet) + return ast.copy_location(new_node, node) + + def visit_Name(self, node: ast.Name): + if node.id in self.array_filter: + return self._replace(node) + return self.generic_visit(node) + + def visit_Subscript(self, node: ast.Subscript): + if isinstance(node.value, ast.Name) and node.value.id in self.array_filter: + return self._replace(node) + return self.generic_visit(node) diff --git a/tests/sdfg/memlet_utils_test.py b/tests/sdfg/memlet_utils_test.py new file mode 100644 index 0000000000..752b9ef55d --- /dev/null +++ b/tests/sdfg/memlet_utils_test.py @@ -0,0 +1,67 @@ +# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import dace +import numpy as np +import pytest +from dace.sdfg import memlet_utils as mu + + +def _replace_zero_with_one(memlet: dace.Memlet) -> dace.Memlet: + for i, s in enumerate(memlet.subset): + if s == 0: + memlet.subset[i] = 1 + return memlet + + +@pytest.mark.parametrize('filter_type', ['none', 'same_array', 'different_array']) +def test_replace_memlet(filter_type): + # Prepare SDFG + sdfg = dace.SDFG('replace_memlet') + sdfg.add_array('A', [2, 2], dace.float64) + sdfg.add_array('B', [1], dace.float64) + state1 = sdfg.add_state() + state2 = sdfg.add_state() + state3 = sdfg.add_state() + end_state = sdfg.add_state() + sdfg.add_edge(state1, state2, dace.InterstateEdge('A[0, 0] > 0')) + sdfg.add_edge(state1, state3, dace.InterstateEdge('A[0, 0] <= 0')) + sdfg.add_edge(state2, end_state, dace.InterstateEdge()) + sdfg.add_edge(state3, end_state, dace.InterstateEdge()) + + t2 = state2.add_tasklet('write_one', {}, {'out'}, 'out = 1') + t3 = state3.add_tasklet('write_two', {}, {'out'}, 'out = 2') + w2 = state2.add_write('B') + w3 = state3.add_write('B') + state2.add_memlet_path(t2, w2, src_conn='out', memlet=dace.Memlet('B')) + state3.add_memlet_path(t3, w3, src_conn='out', memlet=dace.Memlet('B')) + + # Filter memlets + if filter_type == 'none': + filter = set() + elif filter_type == 'same_array': + filter = {'A'} + elif filter_type == 'different_array': + filter = {'B'} + + # Replace memlets in conditions + replacer = mu.MemletReplacer(sdfg.arrays, _replace_zero_with_one, filter) + for e in sdfg.edges(): + e.data.condition.code[0] = replacer.visit(e.data.condition.code[0]) + + # Compile and run + sdfg.compile() + + A = np.array([[1, 1], [1, -1]], dtype=np.float64) + B = np.array([0], dtype=np.float64) + sdfg(A=A, B=B) + + if filter_type in {'none', 'same_array'}: + assert B[0] == 2 + else: + assert B[0] == 1 + + +if __name__ == '__main__': + test_replace_memlet('none') + test_replace_memlet('same_array') + test_replace_memlet('different_array') From 
828eff32155d21b8e4addd11bbc7b73f4e34d95b Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Mon, 12 Dec 2022 15:57:18 +0100 Subject: [PATCH 11/98] Started working on ScheduleTree utilities. --- dace/sdfg/analysis/schedule_tree/utils.py | 52 +++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 dace/sdfg/analysis/schedule_tree/utils.py diff --git a/dace/sdfg/analysis/schedule_tree/utils.py b/dace/sdfg/analysis/schedule_tree/utils.py new file mode 100644 index 0000000000..a1ce9a1f6c --- /dev/null +++ b/dace/sdfg/analysis/schedule_tree/utils.py @@ -0,0 +1,52 @@ +# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +from dace.sdfg import SDFG, SDFGState +from dace.sdfg import nodes as snodes +from dace.sdfg.analysis.schedule_tree import treenodes as tnodes +from dace.sdfg.graph import NodeNotFoundError +from typing import Tuple + + +__df_nodes = ( + # tnodes.ViewNode, tnodes.RefSetNode, + # tnodes.CopyNode, tnodes.DynScopeCopyNode, + tnodes.TaskletNode, tnodes.LibraryCall, + tnodes.MapScope, tnodes.ConsumeScope, tnodes.PipelineScope) + + +def find_tnode_in_sdfg(tnode: tnodes.ScheduleTreeNode, top_level_sdfg: SDFG) -> Tuple[snodes.Node, SDFGState, SDFG]: + if tnode not in __df_nodes: + raise NotImplementedError(f"The `find_dfnode_in_sdfg` does not support {type(tnode)} nodes.") + for n, s in top_level_sdfg.all_nodes_recursive(): + if n is tnode.node: + return n, s, s.parent + raise NodeNotFoundError(f"Node {tnode} not found in SDFG.") + + +def find_snode_in_tree(snode: snodes.Node, tree: tnodes.ScheduleTreeNode) -> Tuple[tnodes.ScheduleTreeScope, tnodes.ScheduleTreeNode]: + pnode = None + cnode = None + frontier = [(tree, child) for child in tree.children] + while frontier: + parent, child = frontier.pop() + if hasattr(child, 'node') and child.node is snode: + pnode = parent + cnode = child + break + frontier.extend([(child, c) for c in child.children]) + if not pnode: + raise NodeNotFoundError(f"Node {snode} not 
found in ScheduleTree.") + return pnode, cnode + + +def find_parent(tnode: tnodes.ScheduleTreeNode, tree: tnodes.ScheduleTreeNode) -> tnodes.ScheduleTreeScope: + pnode = None + frontier = [(tree, child) for child in tree.children] + while frontier: + parent, child = frontier.pop() + if child is tnode: + pnode = parent + break + frontier.extend([(child, c) for c in child.children]) + if not pnode: + raise NodeNotFoundError(f"Node {tnode} not found in ScheduleTree.") + return pnode From 6896955785a8f450b51df172f50e09080552f698 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Mon, 12 Dec 2022 15:58:07 +0100 Subject: [PATCH 12/98] yapf --- dace/sdfg/analysis/schedule_tree/utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/utils.py b/dace/sdfg/analysis/schedule_tree/utils.py index a1ce9a1f6c..75c2785320 100644 --- a/dace/sdfg/analysis/schedule_tree/utils.py +++ b/dace/sdfg/analysis/schedule_tree/utils.py @@ -5,12 +5,14 @@ from dace.sdfg.graph import NodeNotFoundError from typing import Tuple - __df_nodes = ( # tnodes.ViewNode, tnodes.RefSetNode, # tnodes.CopyNode, tnodes.DynScopeCopyNode, - tnodes.TaskletNode, tnodes.LibraryCall, - tnodes.MapScope, tnodes.ConsumeScope, tnodes.PipelineScope) + tnodes.TaskletNode, + tnodes.LibraryCall, + tnodes.MapScope, + tnodes.ConsumeScope, + tnodes.PipelineScope) def find_tnode_in_sdfg(tnode: tnodes.ScheduleTreeNode, top_level_sdfg: SDFG) -> Tuple[snodes.Node, SDFGState, SDFG]: @@ -22,7 +24,8 @@ def find_tnode_in_sdfg(tnode: tnodes.ScheduleTreeNode, top_level_sdfg: SDFG) -> raise NodeNotFoundError(f"Node {tnode} not found in SDFG.") -def find_snode_in_tree(snode: snodes.Node, tree: tnodes.ScheduleTreeNode) -> Tuple[tnodes.ScheduleTreeScope, tnodes.ScheduleTreeNode]: +def find_snode_in_tree(snode: snodes.Node, + tree: tnodes.ScheduleTreeNode) -> Tuple[tnodes.ScheduleTreeScope, tnodes.ScheduleTreeNode]: pnode = None cnode = None frontier = [(tree, child) 
for child in tree.children] From 2a94d711ef5b08ce0abe798a76d543398af30bcc Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Tue, 13 Dec 2022 00:10:10 +0100 Subject: [PATCH 13/98] WIP: MapFission on ScheduleTree. --- .../analysis/schedule_tree/transformations.py | 131 ++++++++++++++++++ dace/sdfg/analysis/schedule_tree/utils.py | 36 ++++- 2 files changed, 164 insertions(+), 3 deletions(-) create mode 100644 dace/sdfg/analysis/schedule_tree/transformations.py diff --git a/dace/sdfg/analysis/schedule_tree/transformations.py b/dace/sdfg/analysis/schedule_tree/transformations.py new file mode 100644 index 0000000000..e6141fc1f8 --- /dev/null +++ b/dace/sdfg/analysis/schedule_tree/transformations.py @@ -0,0 +1,131 @@ +# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +from copy import deepcopy +from dace import data as dt, Memlet, SDFG +from dace.sdfg.analysis.schedule_tree import treenodes as tnodes +from dace.sdfg.analysis.schedule_tree import utils as tutils +from dace.sdfg.graph import NodeNotFoundError +from typing import Dict +from warnings import warn + + +_dataflow_nodes = (tnodes.ViewNode, tnodes.RefSetNode, tnodes.CopyNode, tnodes.DynScopeCopyNode, tnodes.TaskletNode, tnodes.LibraryCall) + + +def _update_memlets(data: Dict[str, dt.Data], memlets: Dict[str, Memlet], index: str, replace: Dict[str, bool]): + for conn, memlet in memlets.items(): + if memlet.data in data: + subset = index if replace[memlet.data] else f"{index}, {memlet.subset}" + memlets[conn] = Memlet(data=memlet.data, subset=subset) + + +def _augment_data(data: Dict[str, dt.Data], map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode, sdfg: SDFG): + + # Generate map-related indices, sizes and, strides + map = map_scope.node.map + index = ", ".join(map.params) + size = map.range.size() + strides = [1] * len(size) + for i in range(len(size) - 2, -1, -1): + strides[i] = strides[i+1] * size[i+1] + + # Augment data descriptors + replace = dict() + for name, 
desc in data.items(): + if isinstance(desc, dt.Scalar): + sdfg.arrays[name] = dt.Array(desc.dtype, size, True, storage=desc.storage) + replace[name] = True + else: + mult = desc.shape[0] + desc.shape = (*size, *desc.shape) + new_strides = [s * mult for s in strides] + desc.strides = (*new_strides, *desc.strides) + replace[name] = False + + # Update memlets + frontier = list(tree.children) + while frontier: + node = frontier.pop() + if isinstance(node, _dataflow_nodes): + try: + _update_memlets(data, node.in_memlets, index, replace) + _update_memlets(data, node.out_memlets, index, replace) + except AttributeError: + subset = index if replace[node.target] else f"{index}, {node.memlet.subset}" + node.memlet = Memlet(data=node.target, subset=subset) + if hasattr(node, 'children'): + frontier.extend(node.children) + + +def map_fission(map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode, sdfg: SDFG) -> bool: + """ + Applies the MapFission transformation to the input MapScope. + + :param map_scope: The MapScope. + :param tree: The ScheduleTree. + :param sdfg: The (top-level) SDFG. + :return: True if the transformation applies successfully, otherwise False. + """ + + #################################### + # Check if MapFission can be applied + + # Basic check: cannot fission an empty MapScope or one that has a single dataflow child. + num_children = len(map_scope.children) + if num_children == 0 or (num_children == 1 and isinstance(map_scope.children[0], _dataflow_nodes)): + return False + + # State-scope check: if the body consists of a single state-scope, certain conditions apply. 
+ partition = tutils.partition_scope_body(map_scope) + if len(partition) == 1: + child = partition[0] + conditions = [] + if isinstance(child, list): + # If-Elif-Else-Scope + for c in child: + if isinstance(c, (tnodes.IfScope, tnodes.ElifScope)): + conditions.append(c.condition) + elif isinstance(child, tnodes.ForScope): + conditions.append(child.header.condition) + elif isinstance(child, tnodes.WhileScope): + conditions.append(child.header.test) + for cond in conditions: + map = map_scope.node.map + if any(p in cond.get_free_symbols() for p in map.params): + return False + # TODO: How to run the check below in the ScheduleTree? + # for s in cond.get_free_symbols(): + # for e in graph.edges_by_connector(self.nested_sdfg, s): + # if any(p in e.data.free_symbols for p in map.params): + # return False + + data_to_augment = dict() + for scope in partition: + if isinstance(scope, _dataflow_nodes): + try: + _, _, sd = tutils.find_tnode_in_sdfg(scope, sdfg) + for _, memlet in scope.out_memlets.items(): + data_to_augment[memlet.data] = sd.arrays[memlet.data] + except NodeNotFoundError: + warn(f"Tree node {scope} not found in SDFG {sdfg}. Switching to unsafe data lookup.") + for _, memlet in scope.out_memlets.items(): + for sd in sdfg.all_sdfgs_recursive(): + if memlet.data in sd.arrays: + data_to_augment[memlet.data] = sd.arrays[memlet.data] + break + except NotImplementedError: + warn(f"Tree node {scope} is unsupported. 
Switching to unsafe data lookup.") + for sd in sdfg.all_sdfgs_recursive(): + if scope.target in sd.arrays: + data_to_augment[scope.target] = sd.arrays[scope.target] + break + _augment_data(data_to_augment, map_scope, tree, sdfg) + + parent_scope = tutils.find_parent(map_scope, tree) + idx = parent_scope.children.index(map_scope) + parent_scope.children.pop(idx) + while len(partition) > 0: + child_scope = partition.pop() + if not isinstance(child_scope, list): + child_scope = [child_scope] + scope = tnodes.MapScope(child_scope, deepcopy(map_scope.node)) + parent_scope.children.insert(idx, scope) diff --git a/dace/sdfg/analysis/schedule_tree/utils.py b/dace/sdfg/analysis/schedule_tree/utils.py index 75c2785320..64ff9464db 100644 --- a/dace/sdfg/analysis/schedule_tree/utils.py +++ b/dace/sdfg/analysis/schedule_tree/utils.py @@ -3,9 +3,9 @@ from dace.sdfg import nodes as snodes from dace.sdfg.analysis.schedule_tree import treenodes as tnodes from dace.sdfg.graph import NodeNotFoundError -from typing import Tuple +from typing import List, Tuple, Union -__df_nodes = ( +_df_nodes = ( # tnodes.ViewNode, tnodes.RefSetNode, # tnodes.CopyNode, tnodes.DynScopeCopyNode, tnodes.TaskletNode, @@ -16,7 +16,7 @@ def find_tnode_in_sdfg(tnode: tnodes.ScheduleTreeNode, top_level_sdfg: SDFG) -> Tuple[snodes.Node, SDFGState, SDFG]: - if tnode not in __df_nodes: + if not isinstance(tnode, _df_nodes): raise NotImplementedError(f"The `find_dfnode_in_sdfg` does not support {type(tnode)} nodes.") for n, s in top_level_sdfg.all_nodes_recursive(): if n is tnode.node: @@ -41,6 +41,7 @@ def find_snode_in_tree(snode: snodes.Node, return pnode, cnode +# TODO: Add `parent` attributes to the tree-nodes to make the following obsolete def find_parent(tnode: tnodes.ScheduleTreeNode, tree: tnodes.ScheduleTreeNode) -> tnodes.ScheduleTreeScope: pnode = None frontier = [(tree, child) for child in tree.children] @@ -53,3 +54,32 @@ def find_parent(tnode: tnodes.ScheduleTreeNode, tree: tnodes.ScheduleTreeNode) 
- if not pnode: raise NodeNotFoundError(f"Node {tnode} not found in ScheduleTree.") return pnode + + +def partition_scope_body(scope: tnodes.ScheduleTreeScope) -> List[Union[tnodes.ScheduleTreeNode, List[tnodes.ScheduleTreeNode]]]: + """ + Partitions a scope's body to ScheduleTree nodes, when they define their own sub-scope, and lists of ScheduleTree + nodes that are children to the same sub-scope. For example, IfScopes, ElifScopes, and ElseScopes are generally + children to a general "If-Elif-Else-Scope". + + :param scope: The scope. + :return: A list of (lists of) ScheduleTree nodes. + """ + + num_children = len(scope.children) + partition = [] + i = 0 + while i < num_children: + child = scope.children[i] + if isinstance(child, tnodes.IfScope): + # Start If-Elif-Else-Scope. + ifelse = [child] + i += 1 + while i < num_children and isinstance(scope.children[i], (tnodes.ElifScope, tnodes.ElseScope)): + ifelse.append(scope.children[i]) + i += 1 + partition.append(ifelse) + else: + partition.append(child) + i += 1 + return partition From c1ca73cb048d818d99c81ea250d0d5b1e4c545ab Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Wed, 14 Dec 2022 22:47:31 +0100 Subject: [PATCH 14/98] Updated ScheduleTree MapFission transformation. 
--- .../analysis/schedule_tree/transformations.py | 53 ++++++++++--------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/transformations.py b/dace/sdfg/analysis/schedule_tree/transformations.py index e6141fc1f8..052591bc96 100644 --- a/dace/sdfg/analysis/schedule_tree/transformations.py +++ b/dace/sdfg/analysis/schedule_tree/transformations.py @@ -3,9 +3,7 @@ from dace import data as dt, Memlet, SDFG from dace.sdfg.analysis.schedule_tree import treenodes as tnodes from dace.sdfg.analysis.schedule_tree import utils as tutils -from dace.sdfg.graph import NodeNotFoundError from typing import Dict -from warnings import warn _dataflow_nodes = (tnodes.ViewNode, tnodes.RefSetNode, tnodes.CopyNode, tnodes.DynScopeCopyNode, tnodes.TaskletNode, tnodes.LibraryCall) @@ -30,9 +28,10 @@ def _augment_data(data: Dict[str, dt.Data], map_scope: tnodes.MapScope, tree: tn # Augment data descriptors replace = dict() - for name, desc in data.items(): + for name, nsdfg in data.items(): + desc = nsdfg.arrays[name] if isinstance(desc, dt.Scalar): - sdfg.arrays[name] = dt.Array(desc.dtype, size, True, storage=desc.storage) + nsdfg.arrays[name] = dt.Array(desc.dtype, size, True, storage=desc.storage) replace[name] = True else: mult = desc.shape[0] @@ -40,6 +39,13 @@ def _augment_data(data: Dict[str, dt.Data], map_scope: tnodes.MapScope, tree: tn new_strides = [s * mult for s in strides] desc.strides = (*new_strides, *desc.strides) replace[name] = False + if sdfg.parent: + nsdfg_node = nsdfg.parent_nsdfg_node + nsdfg_state = nsdfg.parent + nsdfg_node.out_connectors = {**nsdfg_node.out_connectors, name: None} + sdfg.arrays[name] = deepcopy(nsdfg.arrays[name]) + access = nsdfg_state.add_access(name) + nsdfg_state.add_edge(nsdfg_node, name, access, None, Memlet.from_array(name, sdfg.arrays[name])) # Update memlets frontier = list(tree.children) @@ -50,22 +56,23 @@ def _augment_data(data: Dict[str, dt.Data], map_scope: tnodes.MapScope, tree: tn 
_update_memlets(data, node.in_memlets, index, replace) _update_memlets(data, node.out_memlets, index, replace) except AttributeError: - subset = index if replace[node.target] else f"{index}, {node.memlet.subset}" - node.memlet = Memlet(data=node.target, subset=subset) + subset = index if replace[node.memlet.data] else f"{index}, {node.memlet.subset}" + node.memlet = Memlet(data=node.memlet.data, subset=subset) if hasattr(node, 'children'): frontier.extend(node.children) -def map_fission(map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode, sdfg: SDFG) -> bool: +def map_fission(map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode) -> bool: """ Applies the MapFission transformation to the input MapScope. :param map_scope: The MapScope. :param tree: The ScheduleTree. - :param sdfg: The (top-level) SDFG. :return: True if the transformation applies successfully, otherwise False. """ + sdfg = map_scope.sdfg + #################################### # Check if MapFission can be applied @@ -98,26 +105,20 @@ def map_fission(map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode, sdfg: # if any(p in e.data.free_symbols for p in map.params): # return False - data_to_augment = dict() - for scope in partition: + data_to_augment = dict() + frontier = list(partition) + while len(frontier) > 0: + scope = frontier.pop() if isinstance(scope, _dataflow_nodes): try: - _, _, sd = tutils.find_tnode_in_sdfg(scope, sdfg) - for _, memlet in scope.out_memlets.items(): - data_to_augment[memlet.data] = sd.arrays[memlet.data] - except NodeNotFoundError: - warn(f"Tree node {scope} not found in SDFG {sdfg}. Switching to unsafe data lookup.") for _, memlet in scope.out_memlets.items(): - for sd in sdfg.all_sdfgs_recursive(): - if memlet.data in sd.arrays: - data_to_augment[memlet.data] = sd.arrays[memlet.data] - break - except NotImplementedError: - warn(f"Tree node {scope} is unsupported. 
Switching to unsafe data lookup.") - for sd in sdfg.all_sdfgs_recursive(): - if scope.target in sd.arrays: - data_to_augment[scope.target] = sd.arrays[scope.target] - break + if scope.sdfg.arrays[memlet.data].transient: + data_to_augment[memlet.data] = scope.sdfg + except AttributeError: + if scope.target in scope.sdfg.arrays and scope.sdfg.arrays[scope.target].transient: + data_to_augment[scope.target] = scope.sdfg + if hasattr(scope, 'children'): + frontier.extend(scope.children) _augment_data(data_to_augment, map_scope, tree, sdfg) parent_scope = tutils.find_parent(map_scope, tree) @@ -127,5 +128,5 @@ def map_fission(map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode, sdfg: child_scope = partition.pop() if not isinstance(child_scope, list): child_scope = [child_scope] - scope = tnodes.MapScope(child_scope, deepcopy(map_scope.node)) + scope = tnodes.MapScope(sdfg, False, child_scope, deepcopy(map_scope.node)) parent_scope.children.insert(idx, scope) From 273da6c87a85c28e0614de603882ab7acdd81801 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Wed, 14 Dec 2022 22:48:45 +0100 Subject: [PATCH 15/98] Added methods for printing Data (Scalar, Array) as Python arguments. --- dace/data.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/dace/data.py b/dace/data.py index 85f491bc01..26053288e3 100644 --- a/dace/data.py +++ b/dace/data.py @@ -242,6 +242,10 @@ def __hash__(self): def as_arg(self, with_types=True, for_call=False, name=None): """Returns a string for a C++ function signature (e.g., `int *A`). """ raise NotImplementedError + + def as_python_arg(self, with_types=True, for_call=False, name=None): + """Returns a string for a Data-Centric Python function signature (e.g., `A: dace.int32[M]`). 
""" + raise NotImplementedError @property def free_symbols(self) -> Set[symbolic.SymbolicType]: @@ -419,6 +423,13 @@ def as_arg(self, with_types=True, for_call=False, name=None): if not with_types or for_call: return name return self.dtype.as_arg(name) + + def as_python_arg(self, with_types=True, for_call=False, name=None): + if self.storage is dtypes.StorageType.GPU_Global: + return Array(self.dtype, [1]).as_python_arg(with_types, for_call, name) + if not with_types or for_call: + return name + return f"{name}: {dtypes.TYPECLASS_TO_STRING[self.dtype].replace('::', '.')}" def sizes(self): return None @@ -689,6 +700,13 @@ def as_arg(self, with_types=True, for_call=False, name=None): if self.may_alias: return str(self.dtype.ctype) + ' *' + arrname return str(self.dtype.ctype) + ' * __restrict__ ' + arrname + + def as_python_arg(self, with_types=True, for_call=False, name=None): + arrname = name + + if not with_types or for_call: + return arrname + return f"{arrname}: {dtypes.TYPECLASS_TO_STRING[self.dtype].replace('::', '.')}{list(self.shape)}" def sizes(self): return [d.name if isinstance(d, symbolic.symbol) else str(d) for d in self.shape] From 25ccc6bd9037ce316a55bbf3e68a56fb53566229 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Wed, 14 Dec 2022 22:50:06 +0100 Subject: [PATCH 16/98] Minor bug (?) fixes related to parsing DaCe programs generated from ScheduleTrees. --- dace/frontend/python/newast.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 1ac9573212..927b9dea54 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2038,7 +2038,7 @@ def _add_dependencies(self, if i == o and n not in inner_indices: outer_indices.append(n) elif n not in inner_indices: - inner_indices.add(n) + inner_indices.append(n) # Avoid the case where all indices are outer, # i.e., the whole array is carried through the nested SDFG levels. 
if len(outer_indices) < len(irng) or irng.num_elements() == 1: @@ -4857,6 +4857,10 @@ def visit_Subscript(self, node: ast.Subscript, inference: bool = False): # If this subscript originates from an external array, create the # subset in the edge going to the connector, as well as a local # reference to the subset + # old_node = node + # if isinstance(node.value, ast.Name): + # true_node = copy.deepcopy(old_node) + # true_node.value.id = true_name if (true_name not in self.sdfg.arrays and isinstance(node.value, ast.Name)): true_node = copy.deepcopy(node) true_node.value.id = true_name @@ -4872,7 +4876,7 @@ def visit_Subscript(self, node: ast.Subscript, inference: bool = False): if inference: rng.offset(rng, True) return self.sdfg.arrays[true_name].dtype, rng.size() - new_name, new_rng = self._add_read_access(name, rng, node) + new_name, new_rng = self._add_read_access(true_name, rng, node) new_arr = self.sdfg.arrays[new_name] full_rng = subsets.Range.from_array(new_arr) if new_rng.ranges == full_rng.ranges: From d666e55cf57ca24b82c2510af0af798bbb747f82 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Wed, 14 Dec 2022 22:50:47 +0100 Subject: [PATCH 17/98] Added replacement methods for ScheduleTree Copy, Library, and Tasklet nodes. 
--- dace/frontend/python/replacements.py | 66 ++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index 03c6119ea4..afffc8aaff 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -4619,3 +4619,69 @@ def _op(visitor: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, op1: StringLite for op, method in _boolop_to_method.items(): _makeboolop(op, method) + + +# ScheduleTree-related replacements ################################################################################### + + +@oprepo.replaces('dace.tree.tasklet') +def _tasklet(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, label: StringLiteral, inputs: Dict[StringLiteral, str], outputs: Dict[StringLiteral, str], code: StringLiteral, language: dtypes.Language): + label = label.value + inputs = {k.value: v for k, v in inputs.items()} + outputs = {k.value: v for k, v in outputs.items()} + code = code.value + tasklet = state.add_tasklet(label, inputs.keys(), outputs.keys(), code, language) + for conn, name in inputs.items(): + access = state.add_access(name) + state.add_edge(access, None, tasklet, conn, Memlet.from_array(name, sdfg.arrays[name])) + for conn, name in outputs.items(): + access = state.add_access(name) + state.add_edge(tasklet, conn, access, None, Memlet.from_array(name, sdfg.arrays[name])) + + # Handle scope output + for out in outputs.values(): + for (outer_var, outer_rng, _), (var, rng) in pv.accesses.items(): + if out == var: + if not (outer_var, outer_rng, 'w') in pv.accesses: + pv.accesses[(outer_var, outer_rng, 'w')] = (out, rng) + pv.outputs[out] = (state, Memlet(data=outer_var, subset=outer_rng), []) + break + + +@oprepo.replaces('dace.tree.library') +def _library(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, ltype: type, label: StringLiteral, inputs: Dict[StringLiteral, str], outputs: Dict[StringLiteral, str], **kwargs): + label = label.value + 
+ inputs = {k.value: v for k, v in inputs.items()} + outputs = {k.value: v for k, v in outputs.items()} + tasklet = ltype(label, **kwargs) + state.add_node(tasklet) + for conn, name in inputs.items(): + access = state.add_access(name) + state.add_edge(access, None, tasklet, conn, Memlet.from_array(name, sdfg.arrays[name])) + for conn, name in outputs.items(): + access = state.add_access(name) + memlet = Memlet.from_array(name, sdfg.arrays[name]) + state.add_edge(tasklet, conn, access, None, memlet) + + # Handle scope output + for out in outputs.values(): + for (outer_var, outer_rng, _), (var, rng) in pv.accesses.items(): + if out == var: + if not (outer_var, outer_rng, 'w') in pv.accesses: + pv.accesses[(outer_var, outer_rng, 'w')] = (out, rng) + pv.outputs[out] = (state, Memlet(data=outer_var, subset=outer_rng), []) + break + + +@oprepo.replaces('dace.tree.copy') +def _copy(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, src: str, dst: str, wcr: str = None): + src_access = state.add_access(src) + dst_access = state.add_access(dst) + state.add_nedge(src_access, dst_access, Memlet.from_array(dst, sdfg.arrays[dst], wcr=wcr)) + + for (outer_var, outer_rng, _), (var, rng) in pv.accesses.items(): + if dst == var: + if not (outer_var, outer_rng, 'w') in pv.accesses: + pv.accesses[(outer_var, outer_rng, 'w')] = (dst, rng) + pv.outputs[dst] = (state, Memlet(data=outer_var, subset=outer_rng), []) + break From 7489e63c63e59f4d652df69f5a747a5d4c527340 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Wed, 14 Dec 2022 22:51:15 +0100 Subject: [PATCH 18/98] Added the alpha and beta properties to MatMul's init method. 
--- dace/libraries/blas/nodes/matmul.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dace/libraries/blas/nodes/matmul.py b/dace/libraries/blas/nodes/matmul.py index a937af0a81..859dc96675 100644 --- a/dace/libraries/blas/nodes/matmul.py +++ b/dace/libraries/blas/nodes/matmul.py @@ -210,5 +210,7 @@ class MatMul(dace.sdfg.nodes.LibraryNode): default=0, desc="A scalar which will be multiplied with C before adding C") - def __init__(self, name, location=None): + def __init__(self, name, location=None, alpha=1, beta=0): + self.alpha = alpha + self.beta = beta super().__init__(name, location=location, inputs={"_a", "_b"}, outputs={"_c"}) From 63ad7c5d5a79159ce38c21d401a74265b73e8609 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Wed, 14 Dec 2022 22:52:09 +0100 Subject: [PATCH 19/98] Added methods to SDFG that generate DaCe Python signatures. --- dace/sdfg/sdfg.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 011ea56922..a91774b958 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -1365,6 +1365,21 @@ def signature_arglist(self, with_types=True, for_call=False, with_arrays=True, a """ arglist = arglist or self.arglist(scalars_only=not with_arrays) return [v.as_arg(name=k, with_types=with_types, for_call=for_call) for k, v in arglist.items()] + + def python_signature_arglist(self, with_types=True, for_call=False, with_arrays=True, arglist=None) -> List[str]: + """ Returns a list of arguments necessary to call this SDFG, + formatted as a list of Data-Centric Python definitions. + + :param with_types: If True, includes argument types in the result. + :param for_call: If True, returns arguments that can be used when + calling the SDFG. + :param with_arrays: If True, includes arrays, otherwise, + only symbols and scalars are included. + :param arglist: An optional cached argument list. + :return: A list of strings. 
For example: `['A: dace.float32[M]', 'b: dace.int32']`. + """ + arglist = arglist or self.arglist(scalars_only=not with_arrays, free_symbols=[]) + return [v.as_python_arg(name=k, with_types=with_types, for_call=for_call) for k, v in arglist.items()] def signature(self, with_types=True, for_call=False, with_arrays=True, arglist=None) -> str: """ Returns a C/C++ signature of this SDFG, used when generating code. @@ -1380,6 +1395,21 @@ def signature(self, with_types=True, for_call=False, with_arrays=True, arglist=N :param arglist: An optional cached argument list. """ return ", ".join(self.signature_arglist(with_types, for_call, with_arrays, arglist)) + + def python_signature(self, with_types=True, for_call=False, with_arrays=True, arglist=None) -> str: + """ Returns a Data-Centric Python signature of this SDFG, used when generating code. + + :param with_types: If True, includes argument types (can be used + for a function prototype). If False, only + include argument names (can be used for function + calls). + :param for_call: If True, returns arguments that can be used when + calling the SDFG. + :param with_arrays: If True, includes arrays, otherwise, + only symbols and scalars are included. + :param arglist: An optional cached argument list. + """ + return ", ".join(self.python_signature_arglist(with_types, for_call, with_arrays, arglist)) def _repr_html_(self): """ HTML representation of the SDFG, used mainly for Jupyter From 5ab2eb410642c6be39d6601cfdd81e1aed3eefff Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Wed, 14 Dec 2022 22:56:57 +0100 Subject: [PATCH 20/98] Updated the ScheduleTree nodes to point to the corresponding (nested) SDFG. Added methods for printing pseudo-DaCe-Python code. Updated the `as_schedule_tree` method to unsqueeze memlets and rename arrays to match the names in the top-level SDFG. 
--- .../analysis/schedule_tree/sdfg_to_tree.py | 61 ++++++++++--- dace/sdfg/analysis/schedule_tree/treenodes.py | 88 ++++++++++++++++++- 2 files changed, 132 insertions(+), 17 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index e4ef82014c..c51a3fd8b2 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -4,11 +4,12 @@ import dace from dace import symbolic, data from dace.codegen import control_flow as cf -from dace.sdfg.sdfg import SDFG +from dace.sdfg.sdfg import InterstateEdge, SDFG from dace.sdfg.state import SDFGState from dace.sdfg import utils as sdutil, graph as gr from dace.frontend.python.astutils import negate_expr -from dace.sdfg.analysis.schedule_tree import treenodes as tn, passes as stpasses +from dace.sdfg.analysis.schedule_tree import treenodes as tn, passes as stpasses, utils as tutils +from dace.transformation.helpers import unsqueeze_memlet from dace.properties import CodeBlock from dace.memlet import Memlet @@ -50,8 +51,11 @@ def replace_memlets(sdfg: SDFG, array_mapping: Dict[str, Memlet]): :param sdfg: The SDFG. :param array_mapping: A mapping from internal data descriptor names to external memlets. """ - # TODO replace, normalize, and compose - pass + # TODO: Support Interstate edges + for state in sdfg.states(): + for e in state.edges(): + if e.data.data in array_mapping: + e.data = unsqueeze_memlet(e.data, array_mapping[e.data.data]) def remove_name_collisions(sdfg: SDFG): @@ -75,8 +79,27 @@ def remove_name_collisions(sdfg: SDFG): # Rename duplicate data containers for name, desc in nsdfg.arrays.items(): - # Will already be renamed during conversion - if parent_node is not None and not desc.transient: + # TODO: Is it better to do this while parsing the SDFG? 
+ pdesc = desc + pnode = parent_node + csdfg = nsdfg + cname = name + while pnode is not None and not pdesc.transient: + parent_state = csdfg.parent + parent_sdfg = csdfg.parent_sdfg + edge = list(parent_state.edges_by_connector(parent_node, cname))[0] + path = parent_state.memlet_path(edge) + if path[0].src is parent_node: + parent_name = path[-1].dst.data + else: + parent_name = path[0].src.data + pdesc = parent_sdfg.arrays[parent_name] + csdfg = parent_sdfg + pnode = csdfg.parent_nsdfg_node + cname = parent_name + if pnode is None and not pdesc.transient and name != cname: + replacements[name] = cname + name = cname continue if name in identifiers_seen: @@ -108,6 +131,17 @@ def remove_name_collisions(sdfg: SDFG): # If there is a name collision, replace all uses of the old names with the new names if replacements: nsdfg.replace_dict(replacements) + # TODO: Should this be handled differently? + # Replacing connector names + # Replacing edge connector names + if nsdfg.parent_sdfg: + nsdfg.parent_nsdfg_node.in_connectors = {replacements[c]: t for c, t in nsdfg.parent_nsdfg_node.in_connectors.items()} + nsdfg.parent_nsdfg_node.out_connectors = {replacements[c]: t for c, t in nsdfg.parent_nsdfg_node.out_connectors.items()} + for e in nsdfg.parent.all_edges(nsdfg.parent_nsdfg_node): + if e.src_conn in replacements: + e._src_conn = replacements[e.src_conn] + elif e.dst_conn in replacements: + e._dst_conn = replacements[e.dst_conn] def _make_view_node(state: SDFGState, edge: gr.MultiConnectorEdge[Memlet], view_name: str, @@ -206,7 +240,7 @@ def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[ else: target_name = innermost_node.data new_memlet = normalize_memlet(sdfg, state, e, outermost_node.data) - result[e] = tn.CopyNode(target=target_name, memlet=new_memlet) + result[e] = tn.CopyNode(sdfg=sdfg, target=target_name, memlet=new_memlet) return result @@ -243,7 +277,7 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: # 
Create scope node and add to stack scopes.append(result) subnodes = [] - result.append(NODE_TO_SCOPE_TYPE[type(node)](node=node, children=subnodes)) + result.append(NODE_TO_SCOPE_TYPE[type(node)](node=node, sdfg=state.parent, top_level=False, children=subnodes)) result = subnodes elif isinstance(node, dace.nodes.ExitNode): result = scopes.pop() @@ -253,7 +287,6 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: # Replace symbols and memlets in nested SDFGs to match the namespace of the parent SDFG # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes symbolic.safe_replace(node.symbol_mapping, node.sdfg.replace_dict) - replace_memlets(node.sdfg, nested_array_mapping) # Create memlets for nested SDFG mapping, or nview schedule nodes if slice cannot be determined for e in state.all_edges(node): @@ -279,17 +312,19 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: src_desc=sdfg.arrays[e.data.data], view_desc=node.sdfg.arrays[conn])) + replace_memlets(node.sdfg, nested_array_mapping) + # Insert the nested SDFG flattened nested_stree = as_schedule_tree(node.sdfg, in_place=True, toplevel=False) result.extend(nested_stree.children) elif isinstance(node, dace.nodes.Tasklet): in_memlets = {e.dst_conn: e.data for e in state.in_edges(node) if e.dst_conn} out_memlets = {e.src_conn: e.data for e in state.out_edges(node) if e.src_conn} - result.append(tn.TaskletNode(node=node, in_memlets=in_memlets, out_memlets=out_memlets)) + result.append(tn.TaskletNode(sdfg=sdfg, node=node, in_memlets=in_memlets, out_memlets=out_memlets)) elif isinstance(node, dace.nodes.LibraryNode): in_memlets = {e.dst_conn: e.data for e in state.in_edges(node) if e.dst_conn} out_memlets = {e.src_conn: e.data for e in state.out_edges(node) if e.src_conn} - result.append(tn.LibraryCall(node=node, in_memlets=in_memlets, out_memlets=out_memlets)) + result.append(tn.LibraryCall(sdfg=sdfg, node=node, in_memlets=in_memlets, out_memlets=out_memlets)) 
elif isinstance(node, dace.nodes.AccessNode): # If one of the neighboring edges has a schedule tree node attached to it, use that for e in state.all_edges(node): @@ -412,12 +447,12 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche raise tn.UnsupportedScopeException(type(node).__name__) if node.first_state is not None: - result = [tn.StateLabel(state=node.first_state)] + result + result = [tn.StateLabel(sdfg=node.first_state.parent, state=node.first_state)] + result return result # Recursive traversal of the control flow tree - result = tn.ScheduleTreeScope(children=totree(cfg)) + result = tn.ScheduleTreeScope(sdfg=sdfg, top_level=True, children=totree(cfg)) # Clean up tree stpasses.remove_unused_and_duplicate_labels(result) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 2e17fa378a..8f362e4cf1 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -1,12 +1,13 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. 
-from dataclasses import dataclass +from dataclasses import dataclass, field from dace import nodes, data, subsets from dace.codegen import control_flow as cf +from dace.dtypes import TYPECLASS_TO_STRING from dace.properties import CodeBlock from dace.sdfg import SDFG from dace.sdfg.state import SDFGState from dace.memlet import Memlet -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Set, Tuple INDENTATION = ' ' @@ -17,20 +18,61 @@ class UnsupportedScopeException(Exception): @dataclass class ScheduleTreeNode: + sdfg: SDFG + parent: Optional['ScheduleTreeScope'] = field(default=None, init=False) def as_string(self, indent: int = 0): return indent * INDENTATION + 'UNSUPPORTED' + + def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: + string, defined_arrays = self.define_arrays(indent, defined_arrays) + return string + indent * INDENTATION + 'UNSUPPORTED', defined_arrays + + def define_arrays(self, indent: int, defined_arrays: Set[str]) -> Tuple[str, Set[str]]: + defined_arrays = defined_arrays or set() + string = '' + undefined_arrays = {name: desc for name, desc in self.sdfg.arrays.items() if not name in defined_arrays and desc.transient} + for name, desc in undefined_arrays.items(): + string += indent * INDENTATION + f"{name} = numpy.ndarray({desc.shape}, {TYPECLASS_TO_STRING[desc.dtype].replace('::', '.')})\n" + defined_arrays |= undefined_arrays.keys() + return string, defined_arrays @dataclass -class ScheduleTreeScope(ScheduleTreeNode): +class ScheduleTreeScope(ScheduleTreeNode): + top_level: bool children: List['ScheduleTreeNode'] - def __init__(self, children: Optional[List['ScheduleTreeNode']] = None): + def __init__(self, sdfg: Optional[SDFG] = None, top_level: Optional[bool] = False, children: Optional[List['ScheduleTreeNode']] = None): + self.sdfg = sdfg + self.top_level = top_level self.children = children or [] + for child in children: + child.parent = self def as_string(self, 
indent: int = 0): return '\n'.join([child.as_string(indent + 1) for child in self.children]) + + def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: + if self.top_level: + header = '' + for s in self.sdfg.free_symbols: + header += f"{s} = dace.symbol('{s}', {TYPECLASS_TO_STRING[self.sdfg.symbols[s]].replace('::', '.')})" + header += f""" +@dace.program +def {self.sdfg.label}({self.sdfg.python_signature()}): +""" + defined_arrays = set([name for name, desc in self.sdfg.arrays.items() if not desc.transient]) + else: + header = '' + defined_arrays = defined_arrays or set() + string, defined_arrays = self.define_arrays(indent + 1, defined_arrays) + for child in self.children: + substring, defined_arrays = child.as_python(indent + 1, defined_arrays) + string += substring + if string[-1] != '\n': + string += '\n' + return header + string, defined_arrays # TODO: Get input/output memlets? @@ -206,6 +248,12 @@ def as_string(self, indent: int = 0): result = indent * INDENTATION + f'map {", ".join(self.node.map.params)} in [{rangestr}]:\n' return result + super().as_string(indent) + def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: + rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) + result = indent * INDENTATION + f'for {", ".join(self.node.map.params)} in dace.map[{rangestr}]:\n' + string, defined_arrays = super().as_python(indent, defined_arrays) + return result + string, defined_arrays + @dataclass class ConsumeScope(DataflowScope): @@ -243,6 +291,13 @@ def as_string(self, indent: int = 0): out_memlets = ', '.join(f'{v}' for v in self.out_memlets.values()) return indent * INDENTATION + f'{out_memlets} = tasklet({in_memlets})' + def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: + in_memlets = ', '.join(f"'{k}': {v}" for k, v in self.in_memlets.items()) + out_memlets = ', '.join(f"'{k}': {v}" for k, v in 
self.out_memlets.items()) + defined_arrays = defined_arrays or set() + string, defined_arrays = self.define_arrays(indent, defined_arrays) + return string + indent * INDENTATION + f"dace.tree.tasklet(label='{self.node.label}', inputs={{{in_memlets}}}, outputs={{{out_memlets}}}, code='{self.node.code.as_string}', language=dace.{self.node.language})", defined_arrays + @dataclass class LibraryCall(ScheduleTreeNode): @@ -259,6 +314,17 @@ def as_string(self, indent: int = 0): if v.owner not in {nodes.Node, nodes.CodeNode, nodes.LibraryNode}) return indent * INDENTATION + f'{out_memlets} = library {libname}[{own_properties}]({in_memlets})' + def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: + in_memlets = ', '.join(f"'{k}': {v}" for k, v in self.in_memlets.items()) + out_memlets = ', '.join(f"'{k}': {v}" for k, v in self.out_memlets.items()) + libname = type(self.node).__module__ + '.' + type(self.node).__qualname__ + # Get the properties of the library node without its superclasses + own_properties = ', '.join(f'{k}={getattr(self.node, k)}' for k, v in self.node.__properties__.items() + if v.owner not in {nodes.Node, nodes.CodeNode, nodes.LibraryNode}) + defined_arrays = defined_arrays or set() + string, defined_arrays = self.define_arrays(indent, defined_arrays) + return string + indent * INDENTATION + f"dace.tree.library(ltype={libname}, label='{self.node.label}', inputs={{{in_memlets}}}, outputs={{{out_memlets}}}, {own_properties})", defined_arrays + @dataclass class CopyNode(ScheduleTreeNode): @@ -276,6 +342,20 @@ def as_string(self, indent: int = 0): wcr = '' return indent * INDENTATION + f'{self.target}{offset} = copy {self.memlet.data}[{self.memlet.subset}]{wcr}' + + def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: + if self.memlet.other_subset is not None and any(s != 0 for s in self.memlet.other_subset.min_element()): + offset = f'[{self.memlet.other_subset}]' + else: + 
offset = f'[{self.memlet.subset}]' + if self.memlet.wcr is not None: + wcr = f' with {self.memlet.wcr}' + else: + wcr = '' + + defined_arrays = defined_arrays or set() + string, defined_arrays = self.define_arrays(indent, defined_arrays) + return string + indent * INDENTATION + f'dace.tree.copy(src={self.memlet.data}[{self.memlet.subset}], dst={self.target}{offset}, wcr={self.memlet.wcr})', defined_arrays @dataclass From c526195a66eb64fa2857f26f377d3d71ded300af Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Wed, 14 Dec 2022 22:58:22 +0100 Subject: [PATCH 21/98] Removed obsolete ScheduleTree utility methods. --- .../analysis/schedule_tree/transformations.py | 2 +- dace/sdfg/analysis/schedule_tree/utils.py | 50 ------------------- 2 files changed, 1 insertion(+), 51 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/transformations.py b/dace/sdfg/analysis/schedule_tree/transformations.py index 052591bc96..088248a2dc 100644 --- a/dace/sdfg/analysis/schedule_tree/transformations.py +++ b/dace/sdfg/analysis/schedule_tree/transformations.py @@ -121,7 +121,7 @@ def map_fission(map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode) -> bo frontier.extend(scope.children) _augment_data(data_to_augment, map_scope, tree, sdfg) - parent_scope = tutils.find_parent(map_scope, tree) + parent_scope = map_scope.parent idx = parent_scope.children.index(map_scope) parent_scope.children.pop(idx) while len(partition) > 0: diff --git a/dace/sdfg/analysis/schedule_tree/utils.py b/dace/sdfg/analysis/schedule_tree/utils.py index 64ff9464db..2257c203c8 100644 --- a/dace/sdfg/analysis/schedule_tree/utils.py +++ b/dace/sdfg/analysis/schedule_tree/utils.py @@ -5,56 +5,6 @@ from dace.sdfg.graph import NodeNotFoundError from typing import List, Tuple, Union -_df_nodes = ( - # tnodes.ViewNode, tnodes.RefSetNode, - # tnodes.CopyNode, tnodes.DynScopeCopyNode, - tnodes.TaskletNode, - tnodes.LibraryCall, - tnodes.MapScope, - tnodes.ConsumeScope, - tnodes.PipelineScope) - - 
-def find_tnode_in_sdfg(tnode: tnodes.ScheduleTreeNode, top_level_sdfg: SDFG) -> Tuple[snodes.Node, SDFGState, SDFG]: - if not isinstance(tnode, _df_nodes): - raise NotImplementedError(f"The `find_dfnode_in_sdfg` does not support {type(tnode)} nodes.") - for n, s in top_level_sdfg.all_nodes_recursive(): - if n is tnode.node: - return n, s, s.parent - raise NodeNotFoundError(f"Node {tnode} not found in SDFG.") - - -def find_snode_in_tree(snode: snodes.Node, - tree: tnodes.ScheduleTreeNode) -> Tuple[tnodes.ScheduleTreeScope, tnodes.ScheduleTreeNode]: - pnode = None - cnode = None - frontier = [(tree, child) for child in tree.children] - while frontier: - parent, child = frontier.pop() - if hasattr(child, 'node') and child.node is snode: - pnode = parent - cnode = child - break - frontier.extend([(child, c) for c in child.children]) - if not pnode: - raise NodeNotFoundError(f"Node {snode} not found in ScheduleTree.") - return pnode, cnode - - -# TODO: Add `parent` attributes to the tree-nodes to make the following obsolete -def find_parent(tnode: tnodes.ScheduleTreeNode, tree: tnodes.ScheduleTreeNode) -> tnodes.ScheduleTreeScope: - pnode = None - frontier = [(tree, child) for child in tree.children] - while frontier: - parent, child = frontier.pop() - if child is tnode: - pnode = parent - break - frontier.extend([(child, c) for c in child.children]) - if not pnode: - raise NodeNotFoundError(f"Node {tnode} not found in ScheduleTree.") - return pnode - def partition_scope_body(scope: tnodes.ScheduleTreeScope) -> List[Union[tnodes.ScheduleTreeNode, List[tnodes.ScheduleTreeNode]]]: """ From 748f01059cf6ea02534433563da0d33d14f2b757 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Wed, 14 Dec 2022 23:33:48 +0100 Subject: [PATCH 22/98] Added `as_sdfg` method and tests. 
--- .../analysis/schedule_tree/sdfg_to_tree.py | 29 +++++++++++++++ .../analysis/schedule_tree/transformations.py | 2 + tests/sdfg/schedule_tree/conversion_test.py | 35 ++++++++++++++++++ .../schedule_tree/map_fission_test.py | 37 +++++++++++++++++++ 4 files changed, 103 insertions(+) create mode 100644 tests/sdfg/schedule_tree/conversion_test.py create mode 100644 tests/transformations/schedule_tree/map_fission_test.py diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index c51a3fd8b2..f39ac664ac 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -460,6 +460,35 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche return result +def as_sdfg(tree: tn.ScheduleTreeScope) -> SDFG: + """ + Converts a ScheduleTree to its SDFG representation. + + :param tree: The ScheduleTree + :return: The ScheduleTree's SDFG representation + """ + + # Write tree as DaCe Python code. + code, _ = tree.as_python() + + # Save DaCe Python code to temporary file. + import tempfile + tmp = tempfile.NamedTemporaryFile(suffix='.py', delete=False) + tmp.write(b'import dace\n') + tmp.write(b'import numpy\n') + tmp.write(bytes(code, encoding='utf-8')) + tmp.close() + + # Load DaCe Python program from temporary file. 
+ import importlib.util + spec = importlib.util.spec_from_file_location(tmp.name.split('/')[-1][:-3], tmp.name) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + prog = eval(f"mod.{tree.sdfg.label}") + + return prog.to_sdfg() + + if __name__ == '__main__': s = time.time() sdfg = SDFG.from_file(sys.argv[1]) diff --git a/dace/sdfg/analysis/schedule_tree/transformations.py b/dace/sdfg/analysis/schedule_tree/transformations.py index 088248a2dc..3551b6bf7f 100644 --- a/dace/sdfg/analysis/schedule_tree/transformations.py +++ b/dace/sdfg/analysis/schedule_tree/transformations.py @@ -130,3 +130,5 @@ def map_fission(map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode) -> bo child_scope = [child_scope] scope = tnodes.MapScope(sdfg, False, child_scope, deepcopy(map_scope.node)) parent_scope.children.insert(idx, scope) + + return True diff --git a/tests/sdfg/schedule_tree/conversion_test.py b/tests/sdfg/schedule_tree/conversion_test.py new file mode 100644 index 0000000000..2bd1f227f2 --- /dev/null +++ b/tests/sdfg/schedule_tree/conversion_test.py @@ -0,0 +1,35 @@ +# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +import dace +import numpy as np +from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree, as_sdfg + + +# TODO: The test fails because of the ambiguity when having access nodes inside a MapScope but on the same SDFG level. 
+def test_map_with_tasklet_and_library(): + + N = dace.symbol('N') + @dace.program + def map_with_tasklet_and_library(A: dace.float32[N, 5, 5], B: dace.float32[N, 5, 5], cst: dace.int32): + out = np.ndarray((N, 5, 5), dtype=dace.float32) + for i in dace.map[0:N]: + out[i] = cst * (A[i] @ B[i]) + return out + + rng = np.random.default_rng(42) + A = rng.random((10, 5, 5), dtype=np.float32) + B = rng.random((10, 5, 5), dtype=np.float32) + cst = rng.integers(0, 100, dtype=np.int32) + ref = cst * (A @ B) + + val0 = map_with_tasklet_and_library(A, B, cst) + sdfg0 = map_with_tasklet_and_library.to_sdfg() + tree = as_schedule_tree(sdfg0) + sdfg1 = as_sdfg(tree) + val1 = sdfg1(A=A, B=B, cst=cst, N=A.shape[0]) + + assert np.allclose(val0, ref) + assert np.allclose(val1, ref) + + +if __name__ == "__main__": + test_map_with_tasklet_and_library() diff --git a/tests/transformations/schedule_tree/map_fission_test.py b/tests/transformations/schedule_tree/map_fission_test.py new file mode 100644 index 0000000000..addd0c2a3a --- /dev/null +++ b/tests/transformations/schedule_tree/map_fission_test.py @@ -0,0 +1,37 @@ +# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. 
+import dace +import numpy as np +from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree, as_sdfg +from dace.sdfg.analysis.schedule_tree.transformations import map_fission + + +def test_map_with_tasklet_and_library(): + + N = dace.symbol('N') + @dace.program + def map_with_tasklet_and_library(A: dace.float32[N, 5, 5], B: dace.float32[N, 5, 5], cst: dace.int32): + out = np.ndarray((N, 5, 5), dtype=dace.float32) + for i in dace.map[0:N]: + out[i] = cst * (A[i] @ B[i]) + return out + + rng = np.random.default_rng(42) + A = rng.random((10, 5, 5), dtype=np.float32) + B = rng.random((10, 5, 5), dtype=np.float32) + cst = rng.integers(0, 100, dtype=np.int32) + ref = cst * (A @ B) + + val0 = map_with_tasklet_and_library(A, B, cst) + sdfg0 = map_with_tasklet_and_library.to_sdfg() + tree = as_schedule_tree(sdfg0) + result = map_fission(tree.children[0], tree) + assert result + sdfg1 = as_sdfg(tree) + val1 = sdfg1(A=A, B=B, cst=cst, N=A.shape[0]) + + assert np.allclose(val0, ref) + assert np.allclose(val1, ref) + + +if __name__ == "__main__": + test_map_with_tasklet_and_library() From 0f57358110567362a509831e83f0891ed482095e Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Thu, 22 Dec 2022 14:41:24 +0100 Subject: [PATCH 23/98] `inner_indices` is a set --- dace/frontend/python/newast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 927b9dea54..87f91bc3df 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2038,7 +2038,7 @@ def _add_dependencies(self, if i == o and n not in inner_indices: outer_indices.append(n) elif n not in inner_indices: - inner_indices.append(n) + inner_indices.add(n) # Avoid the case where all indices are outer, # i.e., the whole array is carried through the nested SDFG levels. 
if len(outer_indices) < len(irng) or irng.num_elements() == 1: From bed168eb0425a5193efe37aeb61a62f687bda746 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Thu, 22 Dec 2022 15:04:56 +0100 Subject: [PATCH 24/98] Empty `inner_indices` should be an empty set. --- dace/frontend/python/replacements.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index afffc8aaff..fc6333ccbc 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -4644,7 +4644,7 @@ def _tasklet(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, label: StringLi if out == var: if not (outer_var, outer_rng, 'w') in pv.accesses: pv.accesses[(outer_var, outer_rng, 'w')] = (out, rng) - pv.outputs[out] = (state, Memlet(data=outer_var, subset=outer_rng), []) + pv.outputs[out] = (state, Memlet(data=outer_var, subset=outer_rng), set()) break @@ -4669,7 +4669,7 @@ def _library(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, ltype: type, la if out == var: if not (outer_var, outer_rng, 'w') in pv.accesses: pv.accesses[(outer_var, outer_rng, 'w')] = (out, rng) - pv.outputs[out] = (state, Memlet(data=outer_var, subset=outer_rng), []) + pv.outputs[out] = (state, Memlet(data=outer_var, subset=outer_rng), set()) break @@ -4683,5 +4683,5 @@ def _tasklet(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, src: str, dst: if dst == var: if not (outer_var, outer_rng, 'w') in pv.accesses: pv.accesses[(outer_var, outer_rng, 'w')] = (dst, rng) - pv.outputs[dst] = (state, Memlet(data=outer_var, subset=outer_rng), []) + pv.outputs[dst] = (state, Memlet(data=outer_var, subset=outer_rng), set()) break From 5e6ee5c437c9cd2f7cbe37e88d60b24c8a9f5207 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Thu, 22 Dec 2022 15:31:13 +0100 Subject: [PATCH 25/98] Added `is_data_used` method to assist identifying MapScope private variables. 
--- dace/sdfg/analysis/schedule_tree/treenodes.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 8f362e4cf1..0b4173d2d9 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -32,10 +32,23 @@ def define_arrays(self, indent: int, defined_arrays: Set[str]) -> Tuple[str, Set defined_arrays = defined_arrays or set() string = '' undefined_arrays = {name: desc for name, desc in self.sdfg.arrays.items() if not name in defined_arrays and desc.transient} + if hasattr(self, 'children'): + times_used = {name: 0 for name in undefined_arrays} + for child in self.children: + for name in undefined_arrays: + if child.is_data_used(name): + times_used[name] += 1 + undefined_arrays = {name: desc for name, desc in undefined_arrays.items() if times_used[name] > 1} for name, desc in undefined_arrays.items(): string += indent * INDENTATION + f"{name} = numpy.ndarray({desc.shape}, {TYPECLASS_TO_STRING[desc.dtype].replace('::', '.')})\n" defined_arrays |= undefined_arrays.keys() return string, defined_arrays + + def is_data_used(self, name: str) -> bool: + for child in self.children: + if child.is_data_used(name): + return True + return False @dataclass @@ -298,6 +311,9 @@ def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[s string, defined_arrays = self.define_arrays(indent, defined_arrays) return string + indent * INDENTATION + f"dace.tree.tasklet(label='{self.node.label}', inputs={{{in_memlets}}}, outputs={{{out_memlets}}}, code='{self.node.code.as_string}', language=dace.{self.node.language})", defined_arrays + def is_data_used(self, name: str) -> bool: + return name in self.in_memlets.keys() | self.out_memlets.keys() + @dataclass class LibraryCall(ScheduleTreeNode): @@ -325,6 +341,9 @@ def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[s string, 
defined_arrays = self.define_arrays(indent, defined_arrays) return string + indent * INDENTATION + f"dace.tree.library(ltype={libname}, label='{self.node.label}', inputs={{{in_memlets}}}, outputs={{{out_memlets}}}, {own_properties})", defined_arrays + def is_data_used(self, name: str) -> bool: + return name in self.in_memlets.keys() | self.out_memlets.keys() + @dataclass class CopyNode(ScheduleTreeNode): From b4b698dc81a50dd99192d5ed7986aa0703ebf962 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Thu, 22 Dec 2022 15:42:27 +0100 Subject: [PATCH 26/98] Fixed `is_data_used` for TaskletNodes and LibraryCalls. --- dace/sdfg/analysis/schedule_tree/treenodes.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 0b4173d2d9..f456176745 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -312,7 +312,9 @@ def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[s return string + indent * INDENTATION + f"dace.tree.tasklet(label='{self.node.label}', inputs={{{in_memlets}}}, outputs={{{out_memlets}}}, code='{self.node.code.as_string}', language=dace.{self.node.language})", defined_arrays def is_data_used(self, name: str) -> bool: - return name in self.in_memlets.keys() | self.out_memlets.keys() + used_data = set([memlet.data for memlet in self.in_memlets.values()]) + used_data |= set([memlet.data for memlet in self.out_memlets.values()]) + return name in used_data @dataclass @@ -342,7 +344,9 @@ def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[s return string + indent * INDENTATION + f"dace.tree.library(ltype={libname}, label='{self.node.label}', inputs={{{in_memlets}}}, outputs={{{out_memlets}}}, {own_properties})", defined_arrays def is_data_used(self, name: str) -> bool: - return name in self.in_memlets.keys() | 
self.out_memlets.keys() + used_data = set([memlet.data for memlet in self.in_memlets.values()]) + used_data |= set([memlet.data for memlet in self.out_memlets.values()]) + return name in used_data @dataclass From 7e06b7239c2f2625aa8468cbc8a7ff3949fd5f5a Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Mon, 26 Dec 2022 17:55:36 +0100 Subject: [PATCH 27/98] WIP: Adding containers to ScheduleTreeScopes and more MapFission test. --- .../analysis/schedule_tree/sdfg_to_tree.py | 33 +- .../analysis/schedule_tree/transformations.py | 53 ++-- dace/sdfg/analysis/schedule_tree/treenodes.py | 112 +++++-- tests/sdfg/schedule_tree/conversion_test.py | 2 + .../schedule_tree/map_fission_test.py | 284 ++++++++++++++++++ 5 files changed, 434 insertions(+), 50 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index f39ac664ac..4b152e98ed 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -17,6 +17,19 @@ import sys +def populate_containers(scope: tn.ScheduleTreeScope, defined_arrays: Set[str] = None): + defined_arrays = defined_arrays or set() + if scope.top_level: + scope.containers = {name: copy.deepcopy(desc) for name, desc in scope.sdfg.arrays.items() if not desc.transient} + defined_arrays = set(scope.containers.keys()) + _, defined_arrays = scope.define_arrays(0, defined_arrays) + for child in scope.children: + child.parent = scope + if isinstance(child, tn.ScheduleTreeScope): + # _, defined_arrays = child.define_arrays(0, defined_arrays) + populate_containers(child, defined_arrays) + + def normalize_memlet(sdfg: SDFG, state: SDFGState, original: gr.MultiConnectorEdge[Memlet], data: str) -> Memlet: """ Normalizes a memlet to a given data descriptor. 
@@ -77,6 +90,10 @@ def remove_name_collisions(sdfg: SDFG): replacements: Dict[str, str] = {} parent_node = nsdfg.parent_nsdfg_node + # Preserve top-level SDFG names + if not parent_node: + continue + # Rename duplicate data containers for name, desc in nsdfg.arrays.items(): # TODO: Is it better to do this while parsing the SDFG? @@ -97,6 +114,7 @@ def remove_name_collisions(sdfg: SDFG): csdfg = parent_sdfg pnode = csdfg.parent_nsdfg_node cname = parent_name + # if pnode is None and not pdesc.transient and name != cname: if pnode is None and not pdesc.transient and name != cname: replacements[name] = cname name = cname @@ -240,7 +258,8 @@ def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[ else: target_name = innermost_node.data new_memlet = normalize_memlet(sdfg, state, e, outermost_node.data) - result[e] = tn.CopyNode(sdfg=sdfg, target=target_name, memlet=new_memlet) + result[e] = tn.CopyNode(target=target_name, memlet=new_memlet) + # result[e] = tn.CopyNode(sdfg=sdfg, target=target_name, memlet=new_memlet) return result @@ -320,11 +339,13 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: elif isinstance(node, dace.nodes.Tasklet): in_memlets = {e.dst_conn: e.data for e in state.in_edges(node) if e.dst_conn} out_memlets = {e.src_conn: e.data for e in state.out_edges(node) if e.src_conn} - result.append(tn.TaskletNode(sdfg=sdfg, node=node, in_memlets=in_memlets, out_memlets=out_memlets)) + result.append(tn.TaskletNode(node=node, in_memlets=in_memlets, out_memlets=out_memlets)) + # result.append(tn.TaskletNode(sdfg=sdfg, node=node, in_memlets=in_memlets, out_memlets=out_memlets)) elif isinstance(node, dace.nodes.LibraryNode): in_memlets = {e.dst_conn: e.data for e in state.in_edges(node) if e.dst_conn} out_memlets = {e.src_conn: e.data for e in state.out_edges(node) if e.src_conn} - result.append(tn.LibraryCall(sdfg=sdfg, node=node, in_memlets=in_memlets, out_memlets=out_memlets)) + 
result.append(tn.LibraryCall(node=node, in_memlets=in_memlets, out_memlets=out_memlets)) + # result.append(tn.LibraryCall(sdfg=sdfg, node=node, in_memlets=in_memlets, out_memlets=out_memlets)) elif isinstance(node, dace.nodes.AccessNode): # If one of the neighboring edges has a schedule tree node attached to it, use that for e in state.all_edges(node): @@ -447,13 +468,17 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche raise tn.UnsupportedScopeException(type(node).__name__) if node.first_state is not None: - result = [tn.StateLabel(sdfg=node.first_state.parent, state=node.first_state)] + result + result = [tn.StateLabel(state=node.first_state)] + result + # result = [tn.StateLabel(sdfg=node.first_state.parent, state=node.first_state)] + result return result # Recursive traversal of the control flow tree result = tn.ScheduleTreeScope(sdfg=sdfg, top_level=True, children=totree(cfg)) + if toplevel: + populate_containers(result) + # Clean up tree stpasses.remove_unused_and_duplicate_labels(result) diff --git a/dace/sdfg/analysis/schedule_tree/transformations.py b/dace/sdfg/analysis/schedule_tree/transformations.py index 3551b6bf7f..b0abebb41c 100644 --- a/dace/sdfg/analysis/schedule_tree/transformations.py +++ b/dace/sdfg/analysis/schedule_tree/transformations.py @@ -3,7 +3,7 @@ from dace import data as dt, Memlet, SDFG from dace.sdfg.analysis.schedule_tree import treenodes as tnodes from dace.sdfg.analysis.schedule_tree import utils as tutils -from typing import Dict +from typing import Dict, Set _dataflow_nodes = (tnodes.ViewNode, tnodes.RefSetNode, tnodes.CopyNode, tnodes.DynScopeCopyNode, tnodes.TaskletNode, tnodes.LibraryCall) @@ -16,11 +16,12 @@ def _update_memlets(data: Dict[str, dt.Data], memlets: Dict[str, Memlet], index: memlets[conn] = Memlet(data=memlet.data, subset=subset) -def _augment_data(data: Dict[str, dt.Data], map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode, sdfg: SDFG): +# def _augment_data(data: 
Dict[str, dt.Data], map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode, sdfg: SDFG): +def _augment_data(data: Set[str], map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode, sdfg: SDFG): # Generate map-related indices, sizes and, strides map = map_scope.node.map - index = ", ".join(map.params) + index = ", ".join(f"{p} - {r[0]}" for p, r in zip(map.params, map.range)) size = map.range.size() strides = [1] * len(size) for i in range(len(size) - 2, -1, -1): @@ -28,10 +29,13 @@ def _augment_data(data: Dict[str, dt.Data], map_scope: tnodes.MapScope, tree: tn # Augment data descriptors replace = dict() - for name, nsdfg in data.items(): - desc = nsdfg.arrays[name] + # for name, nsdfg in data.items(): + for name in data: + # desc = nsdfg.arrays[name] + desc = map_scope.containers[name] if isinstance(desc, dt.Scalar): - nsdfg.arrays[name] = dt.Array(desc.dtype, size, True, storage=desc.storage) + # nsdfg.arrays[name] = dt.Array(desc.dtype, size, True, storage=desc.storage) + desc = dt.Array(desc.dtype, size, True, storage=desc.storage) replace[name] = True else: mult = desc.shape[0] @@ -39,13 +43,15 @@ def _augment_data(data: Dict[str, dt.Data], map_scope: tnodes.MapScope, tree: tn new_strides = [s * mult for s in strides] desc.strides = (*new_strides, *desc.strides) replace[name] = False - if sdfg.parent: - nsdfg_node = nsdfg.parent_nsdfg_node - nsdfg_state = nsdfg.parent - nsdfg_node.out_connectors = {**nsdfg_node.out_connectors, name: None} - sdfg.arrays[name] = deepcopy(nsdfg.arrays[name]) - access = nsdfg_state.add_access(name) - nsdfg_state.add_edge(nsdfg_node, name, access, None, Memlet.from_array(name, sdfg.arrays[name])) + del map_scope.containers[name] + map_scope.parent.containers[name] = desc + # if sdfg.parent: + # nsdfg_node = nsdfg.parent_nsdfg_node + # nsdfg_state = nsdfg.parent + # nsdfg_node.out_connectors = {**nsdfg_node.out_connectors, name: None} + # sdfg.arrays[name] = deepcopy(nsdfg.arrays[name]) + # access = nsdfg_state.add_access(name) 
+ # nsdfg_state.add_edge(nsdfg_node, name, access, None, Memlet.from_array(name, sdfg.arrays[name])) # Update memlets frontier = list(tree.children) @@ -56,8 +62,9 @@ def _augment_data(data: Dict[str, dt.Data], map_scope: tnodes.MapScope, tree: tn _update_memlets(data, node.in_memlets, index, replace) _update_memlets(data, node.out_memlets, index, replace) except AttributeError: - subset = index if replace[node.memlet.data] else f"{index}, {node.memlet.subset}" - node.memlet = Memlet(data=node.memlet.data, subset=subset) + if node.memlet.data in data: + subset = index if replace[node.memlet.data] else f"{index}, {node.memlet.subset}" + node.memlet = Memlet(data=node.memlet.data, subset=subset) if hasattr(node, 'children'): frontier.extend(node.children) @@ -105,18 +112,23 @@ def map_fission(map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode) -> bo # if any(p in e.data.free_symbols for p in map.params): # return False - data_to_augment = dict() + # data_to_augment = dict() + data_to_augment = set() frontier = list(partition) while len(frontier) > 0: scope = frontier.pop() if isinstance(scope, _dataflow_nodes): try: for _, memlet in scope.out_memlets.items(): - if scope.sdfg.arrays[memlet.data].transient: - data_to_augment[memlet.data] = scope.sdfg + if memlet.data in map_scope.containers: + data_to_augment.add(memlet.data) + # if scope.sdfg.arrays[memlet.data].transient: + # data_to_augment[memlet.data] = scope.sdfg except AttributeError: - if scope.target in scope.sdfg.arrays and scope.sdfg.arrays[scope.target].transient: - data_to_augment[scope.target] = scope.sdfg + if scope.target in map_scope.containers: + data_to_augment.add(scope.target) + # if scope.target in scope.sdfg.arrays and scope.sdfg.arrays[scope.target].transient: + # data_to_augment[scope.target] = scope.sdfg if hasattr(scope, 'children'): frontier.extend(scope.children) _augment_data(data_to_augment, map_scope, tree, sdfg) @@ -129,6 +141,7 @@ def map_fission(map_scope: tnodes.MapScope, 
tree: tnodes.ScheduleTreeNode) -> bo if not isinstance(child_scope, list): child_scope = [child_scope] scope = tnodes.MapScope(sdfg, False, child_scope, deepcopy(map_scope.node)) + scope.parent = parent_scope parent_scope.children.insert(idx, scope) return True diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index f456176745..c3a40bb843 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -1,4 +1,5 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +import copy from dataclasses import dataclass, field from dace import nodes, data, subsets from dace.codegen import control_flow as cf @@ -18,7 +19,8 @@ class UnsupportedScopeException(Exception): @dataclass class ScheduleTreeNode: - sdfg: SDFG + # sdfg: SDFG + # sdfg: Optional[SDFG] = field(default=None, init=False) parent: Optional['ScheduleTreeScope'] = field(default=None, init=False) def as_string(self, indent: int = 0): @@ -29,39 +31,59 @@ def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[s return string + indent * INDENTATION + 'UNSUPPORTED', defined_arrays def define_arrays(self, indent: int, defined_arrays: Set[str]) -> Tuple[str, Set[str]]: - defined_arrays = defined_arrays or set() - string = '' - undefined_arrays = {name: desc for name, desc in self.sdfg.arrays.items() if not name in defined_arrays and desc.transient} - if hasattr(self, 'children'): - times_used = {name: 0 for name in undefined_arrays} - for child in self.children: - for name in undefined_arrays: - if child.is_data_used(name): - times_used[name] += 1 - undefined_arrays = {name: desc for name, desc in undefined_arrays.items() if times_used[name] > 1} - for name, desc in undefined_arrays.items(): - string += indent * INDENTATION + f"{name} = numpy.ndarray({desc.shape}, {TYPECLASS_TO_STRING[desc.dtype].replace('::', '.')})\n" - defined_arrays |= undefined_arrays.keys() - 
return string, defined_arrays + return '', defined_arrays + # defined_arrays = defined_arrays or set() + # string = '' + # undefined_arrays = {name: desc for name, desc in self.sdfg.arrays.items() if not name in defined_arrays and desc.transient} + # if hasattr(self, 'children'): + # times_used = {name: 0 for name in undefined_arrays} + # for child in self.children: + # for name in undefined_arrays: + # if child.is_data_used(name): + # times_used[name] += 1 + # undefined_arrays = {name: desc for name, desc in undefined_arrays.items() if times_used[name] > 1} + # for name, desc in undefined_arrays.items(): + # string += indent * INDENTATION + f"{name} = numpy.ndarray({desc.shape}, {TYPECLASS_TO_STRING[desc.dtype].replace('::', '.')})\n" + # defined_arrays |= undefined_arrays.keys() + # return string, defined_arrays def is_data_used(self, name: str) -> bool: - for child in self.children: - if child.is_data_used(name): - return True - return False + pass + # for child in self.children: + # if child.is_data_used(name): + # return True + # return False @dataclass -class ScheduleTreeScope(ScheduleTreeNode): +class ScheduleTreeScope(ScheduleTreeNode): + sdfg: SDFG top_level: bool children: List['ScheduleTreeNode'] + containers: Optional[Dict[str, data.Data]] = field(default_factory=dict, init=False) + # def __init__(self, sdfg: Optional[SDFG] = None, top_level: Optional[bool] = False, children: Optional[List['ScheduleTreeNode']] = None): def __init__(self, sdfg: Optional[SDFG] = None, top_level: Optional[bool] = False, children: Optional[List['ScheduleTreeNode']] = None): self.sdfg = sdfg self.top_level = top_level self.children = children or [] - for child in children: - child.parent = self + # self.__post_init__() + # for child in children: + # child.parent = self + # _, defined_arrays = self.define_arrays(0, set()) + # self.containers = {name: copy.deepcopy(sdfg.arrays[name]) for name in defined_arrays} + # if top_level: + # self.containers.update({name: 
copy.deepcopy(desc) for name, desc in sdfg.arrays.items() if not desc.transient}) + # # self.containers = {name: copy.deepcopy(container) for name, container in sdfg.arrays.items()} + + # def __post_init__(self): + # for child in self.children: + # child.parent = self + # _, defined_arrays = self.define_arrays(0, set()) + # self.containers = {name: copy.deepcopy(self.sdfg.arrays[name]) for name in defined_arrays} + # if self.top_level: + # self.containers.update({name: copy.deepcopy(desc) for name, desc in self.sdfg.arrays.items() if not desc.transient}) + # # self.containers = {name: copy.deepcopy(container) for name, container in sdfg.arrays.items()} def as_string(self, indent: int = 0): return '\n'.join([child.as_string(indent + 1) for child in self.children]) @@ -70,22 +92,56 @@ def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[s if self.top_level: header = '' for s in self.sdfg.free_symbols: - header += f"{s} = dace.symbol('{s}', {TYPECLASS_TO_STRING[self.sdfg.symbols[s]].replace('::', '.')})" + header += f"{s} = dace.symbol('{s}', {TYPECLASS_TO_STRING[self.sdfg.symbols[s]].replace('::', '.')})\n" header += f""" @dace.program def {self.sdfg.label}({self.sdfg.python_signature()}): """ - defined_arrays = set([name for name, desc in self.sdfg.arrays.items() if not desc.transient]) + # defined_arrays = set([name for name, desc in self.sdfg.arrays.items() if not desc.transient]) + defined_arrays = set([name for name, desc in self.containers.items() if not desc.transient]) else: header = '' defined_arrays = defined_arrays or set() - string, defined_arrays = self.define_arrays(indent + 1, defined_arrays) + cindent = indent + 1 + # string, defined_arrays = self.define_arrays(indent + 1, defined_arrays) + string = '' + undefined_arrays = {name: desc for name, desc in self.containers.items() if name not in defined_arrays} + for name, desc in undefined_arrays.items(): + string += cindent * INDENTATION + f"{name} = numpy.ndarray({desc.shape}, 
{TYPECLASS_TO_STRING[desc.dtype].replace('::', '.')})\n" + defined_arrays |= undefined_arrays.keys() for child in self.children: substring, defined_arrays = child.as_python(indent + 1, defined_arrays) string += substring if string[-1] != '\n': string += '\n' return header + string, defined_arrays + + def define_arrays(self, indent: int, defined_arrays: Set[str]) -> Tuple[str, Set[str]]: + defined_arrays = defined_arrays or set() + string = '' + undefined_arrays = {} + for sdfg in self.sdfg.all_sdfgs_recursive(): + undefined_arrays.update({name: desc for name, desc in sdfg.arrays.items() if not name in defined_arrays and desc.transient}) + # undefined_arrays = {name: desc for name, desc in self.sdfg.arrays.items() if not name in defined_arrays and desc.transient} + times_used = {name: 0 for name in undefined_arrays} + for child in self.children: + for name in undefined_arrays: + if child.is_data_used(name): + times_used[name] += 1 + undefined_arrays = {name: desc for name, desc in undefined_arrays.items() if times_used[name] > 1} + if not self.containers: + self.containers = {} + for name, desc in undefined_arrays.items(): + string += indent * INDENTATION + f"{name} = numpy.ndarray({desc.shape}, {TYPECLASS_TO_STRING[desc.dtype].replace('::', '.')})\n" + self.containers[name] = copy.deepcopy(desc) + defined_arrays |= undefined_arrays.keys() + return string, defined_arrays + + def is_data_used(self, name: str) -> bool: + for child in self.children: + if child.is_data_used(name): + return True + return False # TODO: Get input/output memlets? 
@@ -309,7 +365,8 @@ def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[s out_memlets = ', '.join(f"'{k}': {v}" for k, v in self.out_memlets.items()) defined_arrays = defined_arrays or set() string, defined_arrays = self.define_arrays(indent, defined_arrays) - return string + indent * INDENTATION + f"dace.tree.tasklet(label='{self.node.label}', inputs={{{in_memlets}}}, outputs={{{out_memlets}}}, code='{self.node.code.as_string}', language=dace.{self.node.language})", defined_arrays + code = self.node.code.as_string.replace('\n', '\\n') + return string + indent * INDENTATION + f"dace.tree.tasklet(label='{self.node.label}', inputs={{{in_memlets}}}, outputs={{{out_memlets}}}, code='{code}', language=dace.{self.node.language})", defined_arrays def is_data_used(self, name: str) -> bool: used_data = set([memlet.data for memlet in self.in_memlets.values()]) @@ -380,6 +437,9 @@ def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[s string, defined_arrays = self.define_arrays(indent, defined_arrays) return string + indent * INDENTATION + f'dace.tree.copy(src={self.memlet.data}[{self.memlet.subset}], dst={self.target}{offset}, wcr={self.memlet.wcr})', defined_arrays + def is_data_used(self, name: str) -> bool: + return name is self.memlet.data or name is self.target + @dataclass class DynScopeCopyNode(ScheduleTreeNode): diff --git a/tests/sdfg/schedule_tree/conversion_test.py b/tests/sdfg/schedule_tree/conversion_test.py index 2bd1f227f2..abb1967fbf 100644 --- a/tests/sdfg/schedule_tree/conversion_test.py +++ b/tests/sdfg/schedule_tree/conversion_test.py @@ -24,6 +24,8 @@ def map_with_tasklet_and_library(A: dace.float32[N, 5, 5], B: dace.float32[N, 5, val0 = map_with_tasklet_and_library(A, B, cst) sdfg0 = map_with_tasklet_and_library.to_sdfg() tree = as_schedule_tree(sdfg0) + pcode, _ = tree.as_python() + print(pcode) sdfg1 = as_sdfg(tree) val1 = sdfg1(A=A, B=B, cst=cst, N=A.shape[0]) diff --git 
a/tests/transformations/schedule_tree/map_fission_test.py b/tests/transformations/schedule_tree/map_fission_test.py index addd0c2a3a..2f12c7721c 100644 --- a/tests/transformations/schedule_tree/map_fission_test.py +++ b/tests/transformations/schedule_tree/map_fission_test.py @@ -1,8 +1,11 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. import dace import numpy as np +from dace import nodes from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree, as_sdfg from dace.sdfg.analysis.schedule_tree.transformations import map_fission +from dace.transformation.helpers import nest_state_subgraph +from tests.transformations.mapfission_test import mapfission_sdfg def test_map_with_tasklet_and_library(): @@ -33,5 +36,286 @@ def map_with_tasklet_and_library(A: dace.float32[N, 5, 5], B: dace.float32[N, 5, assert np.allclose(val1, ref) +def test_subgraph(): + + rng = np.random.default_rng(42) + A = rng.random((4, )) + ref = np.zeros([2], dtype=np.float64) + ref[0] = (A[0] + A[1]) + (A[0] * 2 * A[1] * 2) + (A[0] * 3) + 5.0 + ref[1] = (A[2] + A[3]) + (A[2] * 2 * A[3] * 2) + (A[2] * 3) + 5.0 + val = np.empty((2, )) + + sdfg0 = mapfission_sdfg() + tree = as_schedule_tree(sdfg0) + pcode, _ = tree.as_python() + print(pcode) + result = map_fission(tree.children[0], tree) + assert result + pcode, _ = tree.as_python() + print(pcode) + sdfg1 = as_sdfg(tree) + sdfg1(A=A, B=val) + + assert np.allclose(val, ref) + + +def test_nested_subgraph(): + + rng = np.random.default_rng(42) + A = rng.random((4, )) + ref = np.zeros([2], dtype=np.float64) + ref[0] = (A[0] + A[1]) + (A[0] * 2 * A[1] * 2) + (A[0] * 3) + 5.0 + ref[1] = (A[2] + A[3]) + (A[2] * 2 * A[3] * 2) + (A[2] * 3) + 5.0 + val = np.empty((2, )) + + sdfg0 = mapfission_sdfg() + state = sdfg0.nodes()[0] + topmap = next(node for node in state.nodes() if isinstance(node, nodes.MapEntry) and node.label == 'outer') + subgraph = state.scope_subgraph(topmap, include_entry=False, include_exit=False) 
+ nest_state_subgraph(sdfg0, state, subgraph) + tree = as_schedule_tree(sdfg0) + result = map_fission(tree.children[0], tree) + assert result + pcode, _ = tree.as_python() + print(pcode) + sdfg1 = as_sdfg(tree) + sdfg1(A=A, B=val) + + assert np.allclose(val, ref) + + +def test_nested_transient(): + """ Test nested SDFGs with transients. """ + + # Inner SDFG + nsdfg = dace.SDFG('nested') + nsdfg.add_array('a', [1], dace.float64) + nsdfg.add_array('b', [1], dace.float64) + nsdfg.add_transient('t', [1], dace.float64) + + # a->t state + nstate = nsdfg.add_state() + irnode = nstate.add_read('a') + task = nstate.add_tasklet('t1', {'inp'}, {'out'}, 'out = 2*inp') + iwnode = nstate.add_write('t') + nstate.add_edge(irnode, None, task, 'inp', dace.Memlet.simple('a', '0')) + nstate.add_edge(task, 'out', iwnode, None, dace.Memlet.simple('t', '0')) + + # t->a state + first_state = nstate + nstate = nsdfg.add_state() + irnode = nstate.add_read('t') + task = nstate.add_tasklet('t2', {'inp'}, {'out'}, 'out = 3*inp') + iwnode = nstate.add_write('b') + nstate.add_edge(irnode, None, task, 'inp', dace.Memlet.simple('t', '0')) + nstate.add_edge(task, 'out', iwnode, None, dace.Memlet.simple('b', '0')) + + nsdfg.add_edge(first_state, nstate, dace.InterstateEdge()) + + # Outer SDFG + sdfg = dace.SDFG('nested_transient_fission') + sdfg.add_array('A', [2], dace.float64) + state = sdfg.add_state() + rnode = state.add_read('A') + wnode = state.add_write('A') + me, mx = state.add_map('outer', dict(i='0:2')) + nsdfg_node = state.add_nested_sdfg(nsdfg, None, {'a'}, {'b'}) + state.add_memlet_path(rnode, me, nsdfg_node, dst_conn='a', memlet=dace.Memlet.simple('A', 'i')) + state.add_memlet_path(nsdfg_node, mx, wnode, src_conn='b', memlet=dace.Memlet.simple('A', 'i')) + + # self.assertGreater(sdfg.apply_transformations_repeated(MapFission), 0) + tree = as_schedule_tree(sdfg) + result = map_fission(tree.children[0], tree) + assert result + sdfg = as_sdfg(tree) + + # Test + A = np.random.rand(2) + 
expected = A * 6 + sdfg(A=A) + # self.assertTrue(np.allclose(A, expected)) + assert np.allclose(A, expected) + + +def test_inputs_outputs(): + """ + Test subgraphs where the computation modules that are in the middle + connect to the outside. + """ + + sdfg = dace.SDFG('inputs_outputs_fission') + sdfg.add_array('in1', [2], dace.float64) + sdfg.add_array('in2', [2], dace.float64) + sdfg.add_scalar('tmp', dace.float64, transient=True) + sdfg.add_array('out1', [2], dace.float64) + sdfg.add_array('out2', [2], dace.float64) + state = sdfg.add_state() + in1 = state.add_read('in1') + in2 = state.add_read('in2') + out1 = state.add_write('out1') + out2 = state.add_write('out2') + me, mx = state.add_map('outer', dict(i='0:2')) + t1 = state.add_tasklet('t1', {'i1'}, {'o1', 'o2'}, 'o1 = i1 * 2; o2 = i1 * 5') + t2 = state.add_tasklet('t2', {'i1', 'i2'}, {'o1'}, 'o1 = i1 * i2') + state.add_memlet_path(in1, me, t1, dst_conn='i1', memlet=dace.Memlet.simple('in1', 'i')) + state.add_memlet_path(in2, me, t2, dst_conn='i2', memlet=dace.Memlet.simple('in2', 'i')) + state.add_edge(t1, 'o1', t2, 'i1', dace.Memlet.simple('tmp', '0')) + state.add_memlet_path(t2, mx, out1, src_conn='o1', memlet=dace.Memlet.simple('out1', 'i')) + state.add_memlet_path(t1, mx, out2, src_conn='o2', memlet=dace.Memlet.simple('out2', 'i')) + + # self.assertGreater(sdfg.apply_transformations(MapFission), 0) + tree = as_schedule_tree(sdfg) + result = map_fission(tree.children[0], tree) + assert result + sdfg = as_sdfg(tree) + + # Test + A, B, C, D = tuple(np.random.rand(2) for _ in range(4)) + expected_C = (A * 2) * B + expected_D = A * 5 + sdfg(in1=A, in2=B, out1=C, out2=D) + # self.assertTrue(np.allclose(C, expected_C)) + # self.assertTrue(np.allclose(D, expected_D)) + assert np.allclose(C, expected_C) + assert np.allclose(D, expected_D) + + +def test_offsets(): + sdfg = dace.SDFG('mapfission_offsets') + sdfg.add_array('A', [20], dace.float64) + sdfg.add_scalar('interim', dace.float64, transient=True) + state = 
sdfg.add_state() + me, mx = state.add_map('outer', dict(i='10:20')) + + t1 = state.add_tasklet('addone', {'a'}, {'b'}, 'b = a + 1') + t2 = state.add_tasklet('addtwo', {'a'}, {'b'}, 'b = a + 2') + + aread = state.add_read('A') + awrite = state.add_write('A') + state.add_memlet_path(aread, me, t1, dst_conn='a', memlet=dace.Memlet.simple('A', 'i')) + state.add_edge(t1, 'b', t2, 'a', dace.Memlet.simple('interim', '0')) + state.add_memlet_path(t2, mx, awrite, src_conn='b', memlet=dace.Memlet.simple('A', 'i')) + + # self.assertGreater(sdfg.apply_transformations(MapFission), 0) + tree = as_schedule_tree(sdfg) + pcode, _ = tree.as_python() + print(pcode) + result = map_fission(tree.children[0], tree) + assert result + pcode, _ = tree.as_python() + print(pcode) + sdfg = as_sdfg(tree) + + # dace.propagate_memlets_sdfg(sdfg) + # sdfg.validate() + + # Test + A = np.random.rand(20) + expected = A.copy() + expected[10:] += 3 + sdfg(A=A) + # self.assertTrue(np.allclose(A, expected)) + assert np.allclose(A, expected) + + +def test_offsets_array(): + sdfg = dace.SDFG('mapfission_offsets2') + sdfg.add_array('A', [20], dace.float64) + sdfg.add_array('interim', [1], dace.float64, transient=True) + state = sdfg.add_state() + me, mx = state.add_map('outer', dict(i='10:20')) + + t1 = state.add_tasklet('addone', {'a'}, {'b'}, 'b = a + 1') + interim = state.add_access('interim') + t2 = state.add_tasklet('addtwo', {'a'}, {'b'}, 'b = a + 2') + + aread = state.add_read('A') + awrite = state.add_write('A') + state.add_memlet_path(aread, me, t1, dst_conn='a', memlet=dace.Memlet.simple('A', 'i')) + state.add_edge(t1, 'b', interim, None, dace.Memlet.simple('interim', '0')) + state.add_edge(interim, None, t2, 'a', dace.Memlet.simple('interim', '0')) + state.add_memlet_path(t2, mx, awrite, src_conn='b', memlet=dace.Memlet.simple('A', 'i')) + + # self.assertGreater(sdfg.apply_transformations(MapFission), 0) + tree = as_schedule_tree(sdfg) + result = map_fission(tree.children[0], tree) + assert 
result + sdfg = as_sdfg(tree) + + # dace.propagate_memlets_sdfg(sdfg) + # sdfg.validate() + + # Test + A = np.random.rand(20) + expected = A.copy() + expected[10:] += 3 + sdfg(A=A) + # self.assertTrue(np.allclose(A, expected)) + assert np.allclose(A, expected) + + +def test_mapfission_with_symbols(): + ''' + Tests MapFission in the case of a Map containing a NestedSDFG that is using some symbol from the top-level SDFG + missing from the NestedSDFG's symbol mapping. Please note that this is an unusual case that is difficult to + reproduce and ultimately unrelated to MapFission. Consider solving the underlying issue and then deleting this + test and the corresponding (obsolete) code in MapFission. + ''' + + M, N = dace.symbol('M'), dace.symbol('N') + + sdfg = dace.SDFG('tasklet_code_with_symbols') + sdfg.add_array('A', (M, N), dace.int32) + sdfg.add_array('B', (M, N), dace.int32) + + state = sdfg.add_state('parent', is_start_state=True) + me, mx = state.add_map('parent_map', {'i': '0:N'}) + + nsdfg = dace.SDFG('nested_sdfg') + nsdfg.add_scalar('inner_A', dace.int32) + nsdfg.add_scalar('inner_B', dace.int32) + + nstate = nsdfg.add_state('child', is_start_state=True) + na = nstate.add_access('inner_A') + nb = nstate.add_access('inner_B') + ta = nstate.add_tasklet('tasklet_A', {}, {'__out'}, '__out = M') + tb = nstate.add_tasklet('tasklet_B', {}, {'__out'}, '__out = M') + nstate.add_edge(ta, '__out', na, None, dace.Memlet.from_array('inner_A', nsdfg.arrays['inner_A'])) + nstate.add_edge(tb, '__out', nb, None, dace.Memlet.from_array('inner_B', nsdfg.arrays['inner_B'])) + + a = state.add_access('A') + b = state.add_access('B') + t = nodes.NestedSDFG('child_sdfg', nsdfg, {}, {'inner_A', 'inner_B'}, {}) + nsdfg.parent = state + nsdfg.parent_sdfg = sdfg + nsdfg.parent_nsdfg_node = t + state.add_node(t) + state.add_nedge(me, t, dace.Memlet()) + state.add_memlet_path(t, mx, a, memlet=dace.Memlet('A[0, i]'), src_conn='inner_A') + state.add_memlet_path(t, mx, b, 
memlet=dace.Memlet('B[0, i]'), src_conn='inner_B') + + # num = sdfg.apply_transformations_repeated(MapFission) + tree = as_schedule_tree(sdfg) + result = map_fission(tree.children[0], tree) + assert result + sdfg = as_sdfg(tree) + + A = np.ndarray((2, 10), dtype=np.int32) + B = np.ndarray((2, 10), dtype=np.int32) + sdfg(A=A, B=B, M=2, N=10) + + ref = np.full((10, ), fill_value=2, dtype=np.int32) + + assert np.array_equal(A[0], ref) + assert np.array_equal(B[0], ref) + + if __name__ == "__main__": test_map_with_tasklet_and_library() + test_subgraph() + test_nested_subgraph() + test_nested_transient() + test_inputs_outputs() + test_offsets() + test_offsets_array() + test_mapfission_with_symbols() From 1e57142b1824489ce10cb428c33bd1c748bc3943 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Fri, 6 Jan 2023 15:52:04 +0100 Subject: [PATCH 28/98] WIP: ScheduleTree updates. --- .../analysis/schedule_tree/sdfg_to_tree.py | 40 +++++++---- dace/sdfg/analysis/schedule_tree/treenodes.py | 70 ++++++++++++++----- tests/sdfg/schedule_tree/conversion_test.py | 40 ++++++++++- 3 files changed, 117 insertions(+), 33 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 4b152e98ed..87c1ac5bec 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -69,6 +69,9 @@ def replace_memlets(sdfg: SDFG, array_mapping: Dict[str, Memlet]): for e in state.edges(): if e.data.data in array_mapping: e.data = unsqueeze_memlet(e.data, array_mapping[e.data.data]) + # for e in sdfg.edges(): + # for k, v in e.data.assignments.items(): + def remove_name_collisions(sdfg: SDFG): @@ -115,7 +118,7 @@ def remove_name_collisions(sdfg: SDFG): pnode = csdfg.parent_nsdfg_node cname = parent_name # if pnode is None and not pdesc.transient and name != cname: - if pnode is None and not pdesc.transient and name != cname: + if pnode is None and name != cname: 
replacements[name] = cname name = cname continue @@ -153,8 +156,8 @@ def remove_name_collisions(sdfg: SDFG): # Replacing connector names # Replacing edge connector names if nsdfg.parent_sdfg: - nsdfg.parent_nsdfg_node.in_connectors = {replacements[c]: t for c, t in nsdfg.parent_nsdfg_node.in_connectors.items()} - nsdfg.parent_nsdfg_node.out_connectors = {replacements[c]: t for c, t in nsdfg.parent_nsdfg_node.out_connectors.items()} + nsdfg.parent_nsdfg_node.in_connectors = {replacements[c] if c in replacements else c: t for c, t in nsdfg.parent_nsdfg_node.in_connectors.items()} + nsdfg.parent_nsdfg_node.out_connectors = {replacements[c] if c in replacements else c: t for c, t in nsdfg.parent_nsdfg_node.out_connectors.items()} for e in nsdfg.parent.all_edges(nsdfg.parent_nsdfg_node): if e.src_conn in replacements: e._src_conn = replacements[e.src_conn] @@ -342,8 +345,15 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: result.append(tn.TaskletNode(node=node, in_memlets=in_memlets, out_memlets=out_memlets)) # result.append(tn.TaskletNode(sdfg=sdfg, node=node, in_memlets=in_memlets, out_memlets=out_memlets)) elif isinstance(node, dace.nodes.LibraryNode): - in_memlets = {e.dst_conn: e.data for e in state.in_edges(node) if e.dst_conn} - out_memlets = {e.src_conn: e.data for e in state.out_edges(node) if e.src_conn} + # NOTE: LibraryNodes do not necessarily have connectors + if node.in_connectors: + in_memlets = {e.dst_conn: e.data for e in state.in_edges(node) if e.dst_conn} + else: + in_memlets = set([e.data for e in state.in_edges(node)]) + if node.out_connectors: + out_memlets = {e.src_conn: e.data for e in state.out_edges(node) if e.src_conn} + else: + out_memlets = set([e.data for e in state.out_edges(node)]) result.append(tn.LibraryCall(node=node, in_memlets=in_memlets, out_memlets=out_memlets)) # result.append(tn.LibraryCall(sdfg=sdfg, node=node, in_memlets=in_memlets, out_memlets=out_memlets)) elif isinstance(node, 
dace.nodes.AccessNode): @@ -437,32 +447,32 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche if sdfg.out_degree(node.state) == 1 and parent.sequential: # Conditional state in sequential block! Add "if not condition goto exit" result.append( - tn.StateIfScope(condition=CodeBlock(negate_expr(e.data.condition)), + tn.StateIfScope(sdfg=sdfg, top_level=False, condition=CodeBlock(negate_expr(e.data.condition)), children=[tn.GotoNode(target=None)])) result.extend(edge_body) else: # Add "if condition" with the body above - result.append(tn.StateIfScope(condition=e.data.condition, children=edge_body)) + result.append(tn.StateIfScope(sdfg=sdfg, top_level=False, condition=e.data.condition, children=edge_body)) else: result.extend(edge_body) elif isinstance(node, cf.ForScope): - result.append(tn.ForScope(header=node, children=totree(node.body))) + result.append(tn.ForScope(sdfg=sdfg, top_level=False, header=node, children=totree(node.body))) elif isinstance(node, cf.IfScope): - result.append(tn.IfScope(condition=node.condition, children=totree(node.body))) + result.append(tn.IfScope(sdfg=sdfg, top_level=False, condition=node.condition, children=totree(node.body))) if node.orelse is not None: - result.append(tn.ElseScope(children=totree(node.orelse))) + result.append(tn.ElseScope(sdfg=sdfg, top_level=False, children=totree(node.orelse))) elif isinstance(node, cf.IfElseChain): # Add "if" for the first condition, "elif"s for the rest - result.append(tn.IfScope(condition=node.body[0][0], children=totree(node.body[0][1]))) + result.append(tn.IfScope(sdfg=sdfg, top_level=False, condition=node.body[0][0], children=totree(node.body[0][1]))) for cond, body in node.body[1:]: - result.append(tn.ElifScope(condition=cond, children=totree(body))) + result.append(tn.ElifScope(sdfg=sdfg, top_level=False, condition=cond, children=totree(body))) # "else goto exit" - result.append(tn.ElseScope(children=[tn.GotoNode(target=None)])) + 
result.append(tn.ElseScope(sdfg=sdfg, top_level=False, children=[tn.GotoNode(target=None)])) elif isinstance(node, cf.WhileScope): - result.append(tn.WhileScope(header=node, children=totree(node.body))) + result.append(tn.WhileScope(sdfg=sdfg, top_level=False, header=node, children=totree(node.body))) elif isinstance(node, cf.DoWhileScope): - result.append(tn.DoWhileScope(header=node, children=totree(node.body))) + result.append(tn.DoWhileScope(sdfg=sdfg, top_level=False, header=node, children=totree(node.body))) else: # e.g., "SwitchCaseScope" raise tn.UnsupportedScopeException(type(node).__name__) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index c3a40bb843..d9fb876968 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -8,7 +8,7 @@ from dace.sdfg import SDFG from dace.sdfg.state import SDFGState from dace.memlet import Memlet -from typing import Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple, Union INDENTATION = ' ' @@ -88,7 +88,7 @@ def __init__(self, sdfg: Optional[SDFG] = None, top_level: Optional[bool] = Fals def as_string(self, indent: int = 0): return '\n'.join([child.as_string(indent + 1) for child in self.children]) - def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: + def as_python(self, indent: int = 0, defined_arrays: Set[str] = None, def_offset: int = 1, sep_defs: bool = False) -> Tuple[str, Set[str]]: if self.top_level: header = '' for s in self.sdfg.free_symbols: @@ -102,19 +102,23 @@ def {self.sdfg.label}({self.sdfg.python_signature()}): else: header = '' defined_arrays = defined_arrays or set() - cindent = indent + 1 + cindent = indent + def_offset # string, defined_arrays = self.define_arrays(indent + 1, defined_arrays) - string = '' + definitions = '' + body = '' undefined_arrays = {name: desc for name, desc in self.containers.items() if 
name not in defined_arrays} for name, desc in undefined_arrays.items(): - string += cindent * INDENTATION + f"{name} = numpy.ndarray({desc.shape}, {TYPECLASS_TO_STRING[desc.dtype].replace('::', '.')})\n" + definitions += cindent * INDENTATION + f"{name} = numpy.ndarray({desc.shape}, {TYPECLASS_TO_STRING[desc.dtype].replace('::', '.')})\n" defined_arrays |= undefined_arrays.keys() for child in self.children: substring, defined_arrays = child.as_python(indent + 1, defined_arrays) - string += substring - if string[-1] != '\n': - string += '\n' - return header + string, defined_arrays + body += substring + if body[-1] != '\n': + body += '\n' + if sep_defs: + return definitions, body, defined_arrays + else: + return header + definitions + body, defined_arrays def define_arrays(self, indent: int, defined_arrays: Set[str]) -> Tuple[str, Set[str]]: defined_arrays = defined_arrays or set() @@ -213,6 +217,15 @@ def as_string(self, indent: int = 0): result = (indent * INDENTATION + f'for {node.itervar} = {node.init}; {node.condition.as_string}; ' f'{node.itervar} = {node.update}:\n') return result + super().as_string(indent) + + def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: + node = self.header + result = indent * INDENTATION + f'{node.itervar} = {node.init}\n' + result += indent * INDENTATION + f'while {node.condition.as_string}:\n' + defs, body, defined_arrays = super().as_python(indent, defined_arrays, def_offset=0, sep_defs=True) + result = defs + result + body + result += (indent + 1) * INDENTATION + f'{node.itervar} = {node.update}\n' + return result, defined_arrays @dataclass @@ -250,6 +263,11 @@ class IfScope(ControlFlowScope): def as_string(self, indent: int = 0): result = indent * INDENTATION + f'if {self.condition.as_string}:\n' return result + super().as_string(indent) + + def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: + result = indent * INDENTATION + f'if 
{self.condition.as_string}:\n' + string, defined_arrays = super().as_python(indent, defined_arrays) + return result + string, defined_arrays @dataclass @@ -377,12 +395,18 @@ def is_data_used(self, name: str) -> bool: @dataclass class LibraryCall(ScheduleTreeNode): node: nodes.LibraryNode - in_memlets: Dict[str, Memlet] - out_memlets: Dict[str, Memlet] + in_memlets: Union[Dict[str, Memlet], Set[Memlet]] + out_memlets: Union[Dict[str, Memlet], Set[Memlet]] def as_string(self, indent: int = 0): - in_memlets = ', '.join(f'{v}' for v in self.in_memlets.values()) - out_memlets = ', '.join(f'{v}' for v in self.out_memlets.values()) + if isinstance(self.in_memlets, set): + in_memlets = ', '.join(f'{v}' for v in self.in_memlets) + else: + in_memlets = ', '.join(f'{v}' for v in self.in_memlets.values()) + if isinstance(self.out_memlets, set): + out_memlets = ', '.join(f'{v}' for v in self.out_memlets) + else: + out_memlets = ', '.join(f'{v}' for v in self.out_memlets.values()) libname = type(self.node).__name__ # Get the properties of the library node without its superclasses own_properties = ', '.join(f'{k}={getattr(self.node, k)}' for k, v in self.node.__properties__.items() @@ -390,8 +414,14 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{out_memlets} = library {libname}[{own_properties}]({in_memlets})' def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: - in_memlets = ', '.join(f"'{k}': {v}" for k, v in self.in_memlets.items()) - out_memlets = ', '.join(f"'{k}': {v}" for k, v in self.out_memlets.items()) + if isinstance(self.in_memlets, set): + in_memlets = ', '.join(f'{v}' for v in self.in_memlets) + else: + in_memlets = ', '.join(f"'{k}': {v}" for k, v in self.in_memlets.items()) + if isinstance(self.out_memlets, set): + out_memlets = ', '.join(f'{v}' for v in self.out_memlets) + else: + out_memlets = ', '.join(f"'{k}': {v}" for k, v in self.out_memlets.items()) libname = type(self.node).__module__ + 
'.' + type(self.node).__qualname__ # Get the properties of the library node without its superclasses own_properties = ', '.join(f'{k}={getattr(self.node, k)}' for k, v in self.node.__properties__.items() @@ -401,8 +431,14 @@ def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[s return string + indent * INDENTATION + f"dace.tree.library(ltype={libname}, label='{self.node.label}', inputs={{{in_memlets}}}, outputs={{{out_memlets}}}, {own_properties})", defined_arrays def is_data_used(self, name: str) -> bool: - used_data = set([memlet.data for memlet in self.in_memlets.values()]) - used_data |= set([memlet.data for memlet in self.out_memlets.values()]) + if isinstance(self.in_memlets, set): + used_data = set([memlet.data for memlet in self.in_memlets]) + else: + used_data = set([memlet.data for memlet in self.in_memlets.values()]) + if isinstance(self.out_memlets, set): + used_data |= set([memlet.data for memlet in self.out_memlets]) + else: + used_data |= set([memlet.data for memlet in self.out_memlets.values()]) return name in used_data diff --git a/tests/sdfg/schedule_tree/conversion_test.py b/tests/sdfg/schedule_tree/conversion_test.py index abb1967fbf..602644f172 100644 --- a/tests/sdfg/schedule_tree/conversion_test.py +++ b/tests/sdfg/schedule_tree/conversion_test.py @@ -33,5 +33,43 @@ def map_with_tasklet_and_library(A: dace.float32[N, 5, 5], B: dace.float32[N, 5, assert np.allclose(val1, ref) +def test_azimint_naive(): + + N, npt = (dace.symbol(s) for s in ('N', 'npt')) + @dace.program + def dace_azimint_naive(data: dace.float64[N], radius: dace.float64[N]): + rmax = np.amax(radius) + res = np.zeros((npt, ), dtype=np.float64) + for i in range(npt): + # for i in dace.map[0:npt]: + r1 = rmax * i / npt + r2 = rmax * (i + 1) / npt + mask_r12 = np.logical_and((r1 <= radius), (radius < r2)) + on_values = 0 + tmp = np.float64(0) + for j in dace.map[0:N]: + if mask_r12[j]: + tmp += data[j] + on_values += 1 + res[i] = tmp / on_values + return 
res + + rng = np.random.default_rng(42) + SN, Snpt = 1000, 10 + data, radius = rng.random((SN, )), rng.random((SN, )) + # ref = dace_azimint_naive(data, radius, npt=Snpt) + + sdfg0 = dace_azimint_naive.to_sdfg() + tree = as_schedule_tree(sdfg0) + print(tree.as_string()) + pcode, _ = tree.as_python() + print(pcode) + sdfg1 = as_sdfg(tree) + val = sdfg1(data=data, radius=radius, N=SN, npt=Snpt) + + assert np.allclose(val, ref) + + if __name__ == "__main__": - test_map_with_tasklet_and_library() + # test_map_with_tasklet_and_library() + test_azimint_naive() From 2754510296abbd023d5e91598f754a2919294a4d Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Tue, 10 Jan 2023 16:26:35 +0100 Subject: [PATCH 29/98] Experimenting with new passes. --- dace/sdfg/analysis/schedule_tree/passes.py | 67 ++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/dace/sdfg/analysis/schedule_tree/passes.py b/dace/sdfg/analysis/schedule_tree/passes.py index 52a58adc32..279e6c65c9 100644 --- a/dace/sdfg/analysis/schedule_tree/passes.py +++ b/dace/sdfg/analysis/schedule_tree/passes.py @@ -3,6 +3,7 @@ Assortment of passes for schedule trees. """ +from dace import data as dt, Memlet, subsets as sbs, symbolic as sym from dace.sdfg.analysis.schedule_tree import treenodes as tn from typing import Set @@ -61,3 +62,69 @@ def visit(self, node: tn.ScheduleTreeNode): return self.generic_visit(node) return RemoveEmptyScopes().visit(stree) + + +def wcr_to_reduce(stree: tn.ScheduleTreeScope): + """ + Converts WCR assignments to reductions. + + :param stree: The schedule tree to remove WCR assignments from. 
+ """ + + class WCRToReduce(tn.ScheduleNodeTransformer): + + def visit(self, node: tn.ScheduleTreeNode): + + if isinstance(node, tn.TaskletNode): + + wcr_found = False + for _, memlet in node.out_memlets.items(): + if memlet.wcr: + wcr_found = True + break + + if wcr_found: + + loop_found = False + rng = None + idx = None + parent = node.parent + while parent: + if isinstance(parent, (tn.MapScope, tn.ForScope)): + loop_found = True + rng = parent.node.map.range + break + parent = parent.parent + + if loop_found: + + for conn, memlet in node.out_memlets.items(): + if memlet.wcr: + + scope = node.parent + while memlet.data not in scope.containers: + scope = scope.parent + desc = scope.containers[memlet.data] + + shape = rng.size() + list(desc.shape) if not isinstance(desc, dt.Scalar) else rng.size() + parent.containers[f'{memlet.data}_arr'] = dt.Array(desc.dtype, shape, transient=True) + + indices = [(sym.pystr_to_symbolic(s), sym.pystr_to_symbolic(s), 1) for s in parent.node.map.params] + if not isinstance(desc, dt.Scalar): + indices.extend(memlet.subset.ranges) + memlet.subset = sbs.Range(indices) + + from dace.libraries.standard import Reduce + rednode = Reduce(memlet.wcr) + libcall = tn.LibraryCall(rednode, {Memlet.from_array(f'{memlet.data}_arr', parent.containers[f'{memlet.data}_arr'])}, {Memlet.from_array(memlet.data, desc)}) + + memlet.data = f'{memlet.data}_arr' + memlet.wcr = None + + parent.children.append(libcall) + + + return self.generic_visit(node) + + return WCRToReduce().visit(stree) + From 4522a9ebc8a9734ae62503adbf93f0a0f403f77a Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Tue, 10 Jan 2023 16:28:22 +0100 Subject: [PATCH 30/98] Added if-fission transformation. 
--- .../analysis/schedule_tree/sdfg_to_tree.py | 7 ++-- .../analysis/schedule_tree/transformations.py | 32 +++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 87c1ac5bec..5eff1d89b0 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -69,8 +69,11 @@ def replace_memlets(sdfg: SDFG, array_mapping: Dict[str, Memlet]): for e in state.edges(): if e.data.data in array_mapping: e.data = unsqueeze_memlet(e.data, array_mapping[e.data.data]) - # for e in sdfg.edges(): - # for k, v in e.data.assignments.items(): + for e in sdfg.edges(): + syms = e.data.read_symbols() + for s in syms: + if s in array_mapping: + e.data.replace(s, str(array_mapping[s])) diff --git a/dace/sdfg/analysis/schedule_tree/transformations.py b/dace/sdfg/analysis/schedule_tree/transformations.py index b0abebb41c..1dd34b23dd 100644 --- a/dace/sdfg/analysis/schedule_tree/transformations.py +++ b/dace/sdfg/analysis/schedule_tree/transformations.py @@ -91,6 +91,7 @@ def map_fission(map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode) -> bo # State-scope check: if the body consists of a single state-scope, certain conditions apply. partition = tutils.partition_scope_body(map_scope) if len(partition) == 1: + child = partition[0] conditions = [] if isinstance(child, list): @@ -102,10 +103,12 @@ def map_fission(map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode) -> bo conditions.append(child.header.condition) elif isinstance(child, tnodes.WhileScope): conditions.append(child.header.test) + for cond in conditions: map = map_scope.node.map if any(p in cond.get_free_symbols() for p in map.params): return False + # TODO: How to run the check below in the ScheduleTree? 
# for s in cond.get_free_symbols(): # for e in graph.edges_by_connector(self.nested_sdfg, s): @@ -145,3 +148,32 @@ def map_fission(map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode) -> bo parent_scope.children.insert(idx, scope) return True + + +def if_fission(if_scope: tnodes.IfScope, tree: tnodes.ScheduleTreeNode) -> bool: + + parent_scope = if_scope.parent + idx = parent_scope.children.index(if_scope) + if len(parent_scope.children) > idx + 1 and isinstance(parent_scope.children[idx+1], + (tnodes.ElifScope, tnodes.ElseScope)): + return False + + partition = tutils.partition_scope_body(if_scope) + if len(partition) < 2: + return False + + parent_scope.children.pop(idx) + while len(partition) > 0: + child_scope = partition.pop() + if not isinstance(child_scope, list): + child_scope = [child_scope] + scope = tnodes.IfScope(if_scope.sdfg, False, child_scope, deepcopy(if_scope.condition)) + scope.parent = parent_scope + parent_scope.children.insert(idx, scope) + + return True + + +def wcr_to_reduce(map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode) -> bool: + + pass \ No newline at end of file From 2c49e14769f09ebd5ba0d21827a526b699af9930 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Thu, 12 Jan 2023 16:55:51 +0100 Subject: [PATCH 31/98] Experimenting with if canonicalization. --- dace/sdfg/analysis/schedule_tree/passes.py | 58 +++++++++++++++++++ .../analysis/schedule_tree/transformations.py | 20 ++++--- 2 files changed, 71 insertions(+), 7 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/passes.py b/dace/sdfg/analysis/schedule_tree/passes.py index 279e6c65c9..f7ee030035 100644 --- a/dace/sdfg/analysis/schedule_tree/passes.py +++ b/dace/sdfg/analysis/schedule_tree/passes.py @@ -3,6 +3,7 @@ Assortment of passes for schedule trees. 
""" +import copy from dace import data as dt, Memlet, subsets as sbs, symbolic as sym from dace.sdfg.analysis.schedule_tree import treenodes as tn from typing import Set @@ -128,3 +129,60 @@ def visit(self, node: tn.ScheduleTreeNode): return WCRToReduce().visit(stree) + +def canonicalize_if(tree: tn.ScheduleTreeScope): + """ + Canonicalizes sequences of if-elif-else scopes to sequences of if scopes. + """ + + from dace.sdfg.nodes import CodeBlock + from dace.sdfg.analysis.schedule_tree.transformations import if_fission + + class CanonicalizeIf(tn.ScheduleNodeTransformer): + + + def visit(self, node: tn.ScheduleTreeNode): + if not isinstance(node, (tn.ElifScope, tn.ElseScope)): + return super().visit(node) + + parent = node.parent + node_idx = parent.children.index(node) + + conditions = [] + for curr_node in reversed(parent.children[:node_idx]): + conditions.append(curr_node.condition) + if isinstance(curr_node, tn.IfScope): + break + condition = f"not ({' or '.join([f'({c.as_string})' for c in conditions])})" + if isinstance(node, tn.ElifScope): + condition = f"{condition} and {node.condition.as_string}" + new_node = tn.IfScope(node.sdfg, node.top_level, node.children, CodeBlock(condition)) + new_node.parent = node.parent + + return self.generic_visit(new_node) + + CanonicalizeIf().visit(tree) + print(tree.as_string()) + print() + + frontier = list(tree.children) + while frontier: + node = frontier.pop() + if isinstance(node, tn.IfScope): + parent = node.parent + node_idx = parent.children.index(node) + next_node = None + if len(parent.children) > node_idx + 1: + next_node = parent.children[node_idx + 1] + if_fission(node, distribute=True) + end = len(parent.children) + if next_node: + end = parent.children.index(next_node) + for child in parent.children[node_idx: end]: + if len(child.children) > 1: + frontier.append(child) + # frontier.extend(parent.children[node_idx: end]) + elif isinstance(node, tn.ScheduleTreeScope): + frontier.extend(node.children) + + return 
tree diff --git a/dace/sdfg/analysis/schedule_tree/transformations.py b/dace/sdfg/analysis/schedule_tree/transformations.py index 1dd34b23dd..0457e90c1f 100644 --- a/dace/sdfg/analysis/schedule_tree/transformations.py +++ b/dace/sdfg/analysis/schedule_tree/transformations.py @@ -150,24 +150,30 @@ def map_fission(map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode) -> bo return True -def if_fission(if_scope: tnodes.IfScope, tree: tnodes.ScheduleTreeNode) -> bool: +def if_fission(if_scope: tnodes.IfScope, distribute: bool = False) -> bool: + + from dace.sdfg.nodes import CodeBlock parent_scope = if_scope.parent idx = parent_scope.children.index(if_scope) + + # Check transformation conditions + # Scope must not have subsequent elif or else scopes if len(parent_scope.children) > idx + 1 and isinstance(parent_scope.children[idx+1], (tnodes.ElifScope, tnodes.ElseScope)): return False + # Apply transformations partition = tutils.partition_scope_body(if_scope) - if len(partition) < 2: - return False - parent_scope.children.pop(idx) while len(partition) > 0: child_scope = partition.pop() - if not isinstance(child_scope, list): - child_scope = [child_scope] - scope = tnodes.IfScope(if_scope.sdfg, False, child_scope, deepcopy(if_scope.condition)) + if isinstance(child_scope, list) and len(child_scope) == 1 and isinstance(child_scope[0], tnodes.IfScope) and distribute: + scope = tnodes.IfScope(if_scope.sdfg, False, child_scope[0].children, CodeBlock(f"{if_scope.condition.as_string} and {child_scope[0].condition.as_string}")) + else: + if not isinstance(child_scope, list): + child_scope = [child_scope] + scope = tnodes.IfScope(if_scope.sdfg, False, child_scope, deepcopy(if_scope.condition)) scope.parent = parent_scope parent_scope.children.insert(idx, scope) From 754ceea7192b2b0a48fbebb5e51ff0d604eaab4a Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Fri, 13 Jan 2023 16:52:44 +0100 Subject: [PATCH 32/98] Split out SDFG "dealiasing" of 
`remove_name_collisions`. Fixed read-memlet replacement in interstate edges. --- .../analysis/schedule_tree/sdfg_to_tree.py | 79 +++++++++++-------- 1 file changed, 44 insertions(+), 35 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 5eff1d89b0..9ba1afa133 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -17,6 +17,39 @@ import sys +def dealias_sdfg(sdfg: SDFG): + for nsdfg in sdfg.all_sdfgs_recursive(): + + if not nsdfg.parent: + continue + + replacements: Dict[str, str] = {} + + parent_sdfg = nsdfg.parent_sdfg + parent_state = nsdfg.parent + parent_node = nsdfg.parent_nsdfg_node + + for name, desc in nsdfg.arrays.items(): + if desc.transient: + continue + for edge in parent_state.edges_by_connector(parent_node, name): + parent_name = edge.data.data + assert parent_name in parent_sdfg.arrays + if name != parent_name: + replacements[name] = parent_name + break + + if replacements: + nsdfg.replace_dict(replacements) + parent_node.in_connectors = {replacements[c] if c in replacements else c: t for c, t in parent_node.in_connectors.items()} + parent_node.out_connectors = {replacements[c] if c in replacements else c: t for c, t in parent_node.out_connectors.items()} + for e in parent_state.all_edges(parent_node): + if e.src_conn in replacements: + e._src_conn = replacements[e.src_conn] + elif e.dst_conn in replacements: + e._dst_conn = replacements[e.dst_conn] + + def populate_containers(scope: tn.ScheduleTreeScope, defined_arrays: Set[str] = None): defined_arrays = defined_arrays or set() if scope.top_level: @@ -70,10 +103,17 @@ def replace_memlets(sdfg: SDFG, array_mapping: Dict[str, Memlet]): if e.data.data in array_mapping: e.data = unsqueeze_memlet(e.data, array_mapping[e.data.data]) for e in sdfg.edges(): + repl_dict = dict() syms = e.data.read_symbols() + for memlet in e.data.get_read_memlets(sdfg.arrays): + if 
memlet.data in array_mapping: + repl_dict[str(memlet)] = unsqueeze_memlet(memlet, array_mapping[memlet.data]) + if memlet.data in syms: + syms.remove(memlet.data) for s in syms: if s in array_mapping: - e.data.replace(s, str(array_mapping[s])) + repl_dict[s] = str(array_mapping[s]) + e.data.replace_dict(repl_dict) @@ -102,28 +142,8 @@ def remove_name_collisions(sdfg: SDFG): # Rename duplicate data containers for name, desc in nsdfg.arrays.items(): - # TODO: Is it better to do this while parsing the SDFG? - pdesc = desc - pnode = parent_node - csdfg = nsdfg - cname = name - while pnode is not None and not pdesc.transient: - parent_state = csdfg.parent - parent_sdfg = csdfg.parent_sdfg - edge = list(parent_state.edges_by_connector(parent_node, cname))[0] - path = parent_state.memlet_path(edge) - if path[0].src is parent_node: - parent_name = path[-1].dst.data - else: - parent_name = path[0].src.data - pdesc = parent_sdfg.arrays[parent_name] - csdfg = parent_sdfg - pnode = csdfg.parent_nsdfg_node - cname = parent_name - # if pnode is None and not pdesc.transient and name != cname: - if pnode is None and name != cname: - replacements[name] = cname - name = cname + + if not desc.transient: continue if name in identifiers_seen: @@ -155,17 +175,6 @@ def remove_name_collisions(sdfg: SDFG): # If there is a name collision, replace all uses of the old names with the new names if replacements: nsdfg.replace_dict(replacements) - # TODO: Should this be handled differently? 
- # Replacing connector names - # Replacing edge connector names - if nsdfg.parent_sdfg: - nsdfg.parent_nsdfg_node.in_connectors = {replacements[c] if c in replacements else c: t for c, t in nsdfg.parent_nsdfg_node.in_connectors.items()} - nsdfg.parent_nsdfg_node.out_connectors = {replacements[c] if c in replacements else c: t for c, t in nsdfg.parent_nsdfg_node.out_connectors.items()} - for e in nsdfg.parent.all_edges(nsdfg.parent_nsdfg_node): - if e.src_conn in replacements: - e._src_conn = replacements[e.src_conn] - elif e.dst_conn in replacements: - e._dst_conn = replacements[e.dst_conn] def _make_view_node(state: SDFGState, edge: gr.MultiConnectorEdge[Memlet], view_name: str, @@ -404,6 +413,7 @@ def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) if toplevel: # Top-level SDFG preparation (only perform once) # Handle name collisions (in arrays, state labels, symbols) + dealias_sdfg(sdfg) remove_name_collisions(sdfg) ############################# @@ -482,7 +492,6 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche if node.first_state is not None: result = [tn.StateLabel(state=node.first_state)] + result - # result = [tn.StateLabel(sdfg=node.first_state.parent, state=node.first_state)] + result return result From 17d63196287a577462091f2e33a2d9c818c874ca Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Fri, 13 Jan 2023 16:53:58 +0100 Subject: [PATCH 33/98] Added pass to fission scopes. --- dace/sdfg/analysis/schedule_tree/passes.py | 102 ++++++++++++--------- 1 file changed, 59 insertions(+), 43 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/passes.py b/dace/sdfg/analysis/schedule_tree/passes.py index f7ee030035..81cedffea2 100644 --- a/dace/sdfg/analysis/schedule_tree/passes.py +++ b/dace/sdfg/analysis/schedule_tree/passes.py @@ -3,9 +3,9 @@ Assortment of passes for schedule trees. 
""" -import copy from dace import data as dt, Memlet, subsets as sbs, symbolic as sym from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dataclasses import dataclass from typing import Set @@ -136,53 +136,69 @@ def canonicalize_if(tree: tn.ScheduleTreeScope): """ from dace.sdfg.nodes import CodeBlock - from dace.sdfg.analysis.schedule_tree.transformations import if_fission class CanonicalizeIf(tn.ScheduleNodeTransformer): + def visit(self, node: tn.ScheduleTreeNode): + + if isinstance(node, (tn.ElifScope, tn.ElseScope)): + parent = node.parent + assert node in parent.children + node_idx = parent.children.index(node) + + conditions = [] + for curr_node in reversed(parent.children[:node_idx]): + conditions.append(curr_node.condition) + if isinstance(curr_node, tn.IfScope): + break + condition = f"not ({' or '.join([f'({c.as_string})' for c in conditions])})" + if isinstance(node, tn.ElifScope): + condition = f"{condition} and {node.condition.as_string}" + new_node = tn.IfScope(node.sdfg, node.top_level, node.children, CodeBlock(condition)) + new_node.parent = parent + else: + new_node = node + + return self.generic_visit(new_node) + + return CanonicalizeIf().visit(tree) + + +def fission_scopes(node: tn.ScheduleTreeScope): + + from dace.sdfg.analysis.schedule_tree.transformations import loop_fission, if_fission + @dataclass + class FissionScopes(tn.ScheduleNodeTransformer): + + tree: tn.ScheduleTreeScope + + def visit_IfScope(self, node: tn.IfScope): + return if_fission(node, assume_canonical=True, distribute=True) + + def visit_ForScope(self, node: tn.ForScope): + return loop_fission(node, self.tree) + + def visit_MapScope(self, node: tn.MapScope): + return loop_fission(node, self.tree) + def visit(self, node: tn.ScheduleTreeNode): - if not isinstance(node, (tn.ElifScope, tn.ElseScope)): + node = self.generic_visit(node) + if isinstance(node, (tn.IfScope, tn.ForScope, tn.MapScope)): return super().visit(node) - - parent = node.parent - node_idx = 
parent.children.index(node) - - conditions = [] - for curr_node in reversed(parent.children[:node_idx]): - conditions.append(curr_node.condition) - if isinstance(curr_node, tn.IfScope): - break - condition = f"not ({' or '.join([f'({c.as_string})' for c in conditions])})" - if isinstance(node, tn.ElifScope): - condition = f"{condition} and {node.condition.as_string}" - new_node = tn.IfScope(node.sdfg, node.top_level, node.children, CodeBlock(condition)) - new_node.parent = node.parent + return node - return self.generic_visit(new_node) + return FissionScopes(node).visit(node) - CanonicalizeIf().visit(tree) - print(tree.as_string()) - print() - - frontier = list(tree.children) - while frontier: - node = frontier.pop() - if isinstance(node, tn.IfScope): - parent = node.parent - node_idx = parent.children.index(node) - next_node = None - if len(parent.children) > node_idx + 1: - next_node = parent.children[node_idx + 1] - if_fission(node, distribute=True) - end = len(parent.children) - if next_node: - end = parent.children.index(next_node) - for child in parent.children[node_idx: end]: - if len(child.children) > 1: - frontier.append(child) - # frontier.extend(parent.children[node_idx: end]) - elif isinstance(node, tn.ScheduleTreeScope): - frontier.extend(node.children) - - return tree + +def validate(node: tn.ScheduleTreeNode) -> bool: + + if isinstance(node, tn.ScheduleTreeScope): + if any(child.parent is not node for child in node.children): + return False + if all(validate(child) for child in node.children): + return True + else: + return False + + return True From 83ae8171d65d3fff71ab463b81a0e39c8dd54196 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Fri, 13 Jan 2023 16:55:58 +0100 Subject: [PATCH 34/98] Reworked map-fission to loop-fission (support Maps and ForLoops). 
--- .../analysis/schedule_tree/transformations.py | 171 +++++++++++++----- 1 file changed, 128 insertions(+), 43 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/transformations.py b/dace/sdfg/analysis/schedule_tree/transformations.py index 0457e90c1f..2f2a36a6d9 100644 --- a/dace/sdfg/analysis/schedule_tree/transformations.py +++ b/dace/sdfg/analysis/schedule_tree/transformations.py @@ -1,9 +1,10 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. -from copy import deepcopy -from dace import data as dt, Memlet, SDFG +import copy +from dace import data as dt, Memlet, SDFG, subsets from dace.sdfg.analysis.schedule_tree import treenodes as tnodes from dace.sdfg.analysis.schedule_tree import utils as tutils -from typing import Dict, Set +import re +from typing import Dict, List, Set, Union _dataflow_nodes = (tnodes.ViewNode, tnodes.RefSetNode, tnodes.CopyNode, tnodes.DynScopeCopyNode, tnodes.TaskletNode, tnodes.LibraryCall) @@ -16,25 +17,56 @@ def _update_memlets(data: Dict[str, dt.Data], memlets: Dict[str, Memlet], index: memlets[conn] = Memlet(data=memlet.data, subset=subset) -# def _augment_data(data: Dict[str, dt.Data], map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode, sdfg: SDFG): -def _augment_data(data: Set[str], map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode, sdfg: SDFG): +def _augment_data(data: Set[str], loop: Union[tnodes.MapScope, tnodes.ForScope], tree: tnodes.ScheduleTreeNode): - # Generate map-related indices, sizes and, strides - map = map_scope.node.map - index = ", ".join(f"{p} - {r[0]}" for p, r in zip(map.params, map.range)) - size = map.range.size() + # Generate loop-related indices, sizes and, strides + if isinstance(loop, tnodes.MapScope): + map = loop.node.map + index = ", ".join(f"{p}/{r[2]}-{r[0]}" if r[2] != 1 else f"{p}-{r[0]}" for p, r in zip(map.params, map.range)) + size = map.range.size() + else: + itervar = loop.header.itervar + start = loop.header.init + # NOTE: Condition 
expression may be inside parentheses + par = re.search(f"^\s*(\(?)\s*{itervar}", loop.header.condition.as_string).group(1) == '(' + if par: + stop_match = re.search(f"\(\s*{itervar}\s*([<>=]+)\s*(.+)\s*\)", loop.header.condition.as_string) + else: + stop_match = re.search(f"{itervar}\s*([<>=]+)\s*(.+)", loop.header.condition.as_string) + stop_op = stop_match.group(1) + assert stop_op in ("<", "<=", ">", ">=") + stop = stop_match.group(2) + # NOTE: Update expression may be inside parentheses + par = re.search(f"^\s*(\(?)\s*{itervar}", loop.header.update).group(1) == '(' + if par: + step_match = re.search(f"\(\s*{itervar}\s*([(*+-/%)]+)\s*([a-zA-Z0-9_]+)\s*\)", loop.header.update) + else: + step_match = re.search(f"{itervar}\s*([(*+-/%)]+)\s*([a-zA-Z0-9_]+)", loop.header.update) + try: + step_op = step_match.group(1) + step = step_match.group(2) + if step_op == '+': + step = int(step) + index = f"{itervar}/{step}-{start}" if step != 1 else f"{itervar}-{start}" + else: + raise ValueError + except (AttributeError, ValueError): + step = 1 if '<' in stop_op else -1 + index = itervar + if "=" in stop_op: + stop = f"{stop} + ({step})" + size = subsets.Range.from_string(f"{start}:{stop}:{step}").size() + strides = [1] * len(size) for i in range(len(size) - 2, -1, -1): strides[i] = strides[i+1] * size[i+1] # Augment data descriptors replace = dict() - # for name, nsdfg in data.items(): for name in data: - # desc = nsdfg.arrays[name] - desc = map_scope.containers[name] + desc = loop.containers[name] if isinstance(desc, dt.Scalar): - # nsdfg.arrays[name] = dt.Array(desc.dtype, size, True, storage=desc.storage) + desc = dt.Array(desc.dtype, size, True, storage=desc.storage) replace[name] = True else: @@ -43,15 +75,8 @@ def _augment_data(data: Set[str], map_scope: tnodes.MapScope, tree: tnodes.Sched new_strides = [s * mult for s in strides] desc.strides = (*new_strides, *desc.strides) replace[name] = False - del map_scope.containers[name] - map_scope.parent.containers[name] = 
desc - # if sdfg.parent: - # nsdfg_node = nsdfg.parent_nsdfg_node - # nsdfg_state = nsdfg.parent - # nsdfg_node.out_connectors = {**nsdfg_node.out_connectors, name: None} - # sdfg.arrays[name] = deepcopy(nsdfg.arrays[name]) - # access = nsdfg_state.add_access(name) - # nsdfg_state.add_edge(nsdfg_node, name, access, None, Memlet.from_array(name, sdfg.arrays[name])) + del loop.containers[name] + loop.parent.containers[name] = desc # Update memlets frontier = list(tree.children) @@ -69,6 +94,61 @@ def _augment_data(data: Set[str], map_scope: tnodes.MapScope, tree: tnodes.Sched frontier.extend(node.children) + +def loop_fission(loop: Union[tnodes.MapScope, tnodes.ForScope], tree: tnodes.ScheduleTreeNode) -> List[Union[tnodes.MapScope, tnodes.ForScope]]: + """ + Applies the LoopFission transformation to the input MapScope or ForScope. + + :param loop: The MapScope or ForScope. + :param tree: The ScheduleTree. + :return: True if the transformation applies successfully, otherwise False. + """ + + sdfg = loop.sdfg + + #################################### + # Check if LoopFission can be applied + + # Basic check: cannot fission an empty MapScope/ForScope or one that has a single child. 
+ partition = tutils.partition_scope_body(loop) + if len(partition) < 2: + return [loop] + + data_to_augment = set() + frontier = list(partition) + while len(frontier) > 0: + scope = frontier.pop() + if isinstance(scope, _dataflow_nodes): + try: + for _, memlet in scope.out_memlets.items(): + if memlet.data in loop.containers: + data_to_augment.add(memlet.data) + except AttributeError: + if scope.target in loop.containers: + data_to_augment.add(scope.target) + if hasattr(scope, 'children'): + frontier.extend(scope.children) + _augment_data(data_to_augment, loop, tree) + + new_scopes = [] + while partition: + child = partition.pop(0) + if not isinstance(child, list): + child = [child] + if isinstance(loop, tnodes.MapScope): + scope = tnodes.MapScope(sdfg, False, child, copy.deepcopy(loop.node)) + else: + scope = tnodes.ForScope(sdfg, False, child, copy.copy(loop.header)) + for child in scope.children: + child.parent = scope + if isinstance(child, tnodes.ScheduleTreeScope): + scope.containers.update(child.containers) + scope.parent = loop.parent + new_scopes.append(scope) + + return new_scopes + + def map_fission(map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode) -> bool: """ Applies the MapFission transformation to the input MapScope. 
@@ -134,7 +214,7 @@ def map_fission(map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode) -> bo # data_to_augment[scope.target] = scope.sdfg if hasattr(scope, 'children'): frontier.extend(scope.children) - _augment_data(data_to_augment, map_scope, tree, sdfg) + _augment_data(data_to_augment, map_scope, tree) parent_scope = map_scope.parent idx = parent_scope.children.index(map_scope) @@ -143,41 +223,46 @@ def map_fission(map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode) -> bo child_scope = partition.pop() if not isinstance(child_scope, list): child_scope = [child_scope] - scope = tnodes.MapScope(sdfg, False, child_scope, deepcopy(map_scope.node)) + scope = tnodes.MapScope(sdfg, False, child_scope, copy.deepcopy(map_scope.node)) scope.parent = parent_scope parent_scope.children.insert(idx, scope) return True -def if_fission(if_scope: tnodes.IfScope, distribute: bool = False) -> bool: +def if_fission(if_scope: tnodes.IfScope, assume_canonical: bool = False, distribute: bool = False) -> List[tnodes.IfScope]: from dace.sdfg.nodes import CodeBlock - parent_scope = if_scope.parent - idx = parent_scope.children.index(if_scope) - # Check transformation conditions # Scope must not have subsequent elif or else scopes - if len(parent_scope.children) > idx + 1 and isinstance(parent_scope.children[idx+1], - (tnodes.ElifScope, tnodes.ElseScope)): - return False + if not assume_canonical: + idx = if_scope.parent.children.index(if_scope) + if len(if_scope.parent.children) > idx + 1 and isinstance(if_scope.parent.children[idx+1], + (tnodes.ElifScope, tnodes.ElseScope)): + return [if_scope] + if len(if_scope.children) < 2 and not (isinstance(if_scope.children[0], tnodes.IfScope) and distribute): + return [if_scope] - # Apply transformations + new_scopes = [] partition = tutils.partition_scope_body(if_scope) - parent_scope.children.pop(idx) - while len(partition) > 0: - child_scope = partition.pop() - if isinstance(child_scope, list) and len(child_scope) == 1 and 
isinstance(child_scope[0], tnodes.IfScope) and distribute: - scope = tnodes.IfScope(if_scope.sdfg, False, child_scope[0].children, CodeBlock(f"{if_scope.condition.as_string} and {child_scope[0].condition.as_string}")) + while partition: + child = partition.pop(0) + if isinstance(child, list) and len(child) == 1 and isinstance(child[0], tnodes.IfScope) and distribute: + scope = tnodes.IfScope(if_scope.sdfg, False, child[0].children, CodeBlock(f"{if_scope.condition.as_string} and {child[0].condition.as_string}")) + scope.containers.update(child[0].containers) else: - if not isinstance(child_scope, list): - child_scope = [child_scope] - scope = tnodes.IfScope(if_scope.sdfg, False, child_scope, deepcopy(if_scope.condition)) - scope.parent = parent_scope - parent_scope.children.insert(idx, scope) + if not isinstance(child, list): + child = [child] + scope = tnodes.IfScope(if_scope.sdfg, False, child, copy.deepcopy(if_scope.condition)) + for child in scope.children: + child.parent = scope + if isinstance(child, tnodes.ScheduleTreeScope): + scope.containers.update(child.containers) + scope.parent = if_scope.parent + new_scopes.append(scope) - return True + return new_scopes def wcr_to_reduce(map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode) -> bool: From 52831e3f8f1b890999aacd5248855f2404e9e617 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Fri, 13 Jan 2023 16:57:06 +0100 Subject: [PATCH 35/98] Children must always point to the correct parent. 
--- dace/sdfg/analysis/schedule_tree/treenodes.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index d9fb876968..57538b2747 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -67,6 +67,9 @@ def __init__(self, sdfg: Optional[SDFG] = None, top_level: Optional[bool] = Fals self.sdfg = sdfg self.top_level = top_level self.children = children or [] + if self.children: + for child in children: + child.parent = self # self.__post_init__() # for child in children: # child.parent = self @@ -568,5 +571,7 @@ def generic_visit(self, node: ScheduleTreeNode): new_values.extend(value) continue new_values.append(value) + for val in new_values: + val.parent = node node.children[:] = new_values return node From eed6777eafd507c3a00e140b4dc1ce932033b7e2 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Thu, 19 Jan 2023 17:32:15 +0100 Subject: [PATCH 36/98] WIP: Fission with AssignNodes in body --- .../analysis/schedule_tree/sdfg_to_tree.py | 5 +- .../analysis/schedule_tree/transformations.py | 56 +++++++++++++++++-- dace/sdfg/analysis/schedule_tree/treenodes.py | 14 ++++- 3 files changed, 66 insertions(+), 9 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 9ba1afa133..0ed1ee779b 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -54,6 +54,9 @@ def populate_containers(scope: tn.ScheduleTreeScope, defined_arrays: Set[str] = defined_arrays = defined_arrays or set() if scope.top_level: scope.containers = {name: copy.deepcopy(desc) for name, desc in scope.sdfg.arrays.items() if not desc.transient} + scope.symbols = dict() + for sdfg in scope.sdfg.all_sdfgs_recursive(): + scope.symbols.update(sdfg.symbols) defined_arrays = set(scope.containers.keys()) _, 
defined_arrays = scope.define_arrays(0, defined_arrays) for child in scope.children: @@ -445,7 +448,7 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche if e not in parent.assignments_to_ignore: for aname, aval in e.data.assignments.items(): - edge_body.append(tn.AssignNode(name=aname, value=CodeBlock(aval))) + edge_body.append(tn.AssignNode(name=aname, value=CodeBlock(aval), edge=InterstateEdge(assignments={aname: aval}))) if not parent.sequential: if e not in parent.gotos_to_ignore: diff --git a/dace/sdfg/analysis/schedule_tree/transformations.py b/dace/sdfg/analysis/schedule_tree/transformations.py index 2f2a36a6d9..ccc5413118 100644 --- a/dace/sdfg/analysis/schedule_tree/transformations.py +++ b/dace/sdfg/analysis/schedule_tree/transformations.py @@ -1,6 +1,7 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. import copy -from dace import data as dt, Memlet, SDFG, subsets +from dace import data as dt, dtypes, Memlet, SDFG, subsets +from dace.sdfg import nodes as dnodes from dace.sdfg.analysis.schedule_tree import treenodes as tnodes from dace.sdfg.analysis.schedule_tree import utils as tutils import re @@ -15,7 +16,20 @@ def _update_memlets(data: Dict[str, dt.Data], memlets: Dict[str, Memlet], index: if memlet.data in data: subset = index if replace[memlet.data] else f"{index}, {memlet.subset}" memlets[conn] = Memlet(data=memlet.data, subset=subset) - + # else: + # repl_dict = dict() + # for s in memlet.subset.free_symbols: + # if s in data: + # repl_dict[s] = f"{s}({index})" + # if repl_dict: + # memlet.subset.replace(repl_dict) + # if memlet.other_subset: + # repl_dict = dict() + # for s in memlet.other_subset.free_symbols: + # if s in data: + # repl_dict[s] = f"{s}({index})" + # if repl_dict: + # memlet.other_subset.replace(repl_dict) def _augment_data(data: Set[str], loop: Union[tnodes.MapScope, tnodes.ForScope], tree: tnodes.ScheduleTreeNode): @@ -115,6 +129,7 @@ def loop_fission(loop: 
Union[tnodes.MapScope, tnodes.ForScope], tree: tnodes.Sch return [loop] data_to_augment = set() + assignments = dict() frontier = list(partition) while len(frontier) > 0: scope = frontier.pop() @@ -126,15 +141,46 @@ def loop_fission(loop: Union[tnodes.MapScope, tnodes.ForScope], tree: tnodes.Sch except AttributeError: if scope.target in loop.containers: data_to_augment.add(scope.target) - if hasattr(scope, 'children'): + elif isinstance(scope, tnodes.AssignNode): + symbol = tree.symbols[scope.name] + loop.containers[scope.name] = dt.Scalar(symbol.dtype, transient=True) + data_to_augment.add(scope.name) + repl_dict = {scope.name: '__out'} + out_memlets = {'__out': Memlet(data=scope.name, subset='0')} + in_memlets = dict() + for i, memlet in enumerate(scope.edge.get_read_memlets(scope.parent.sdfg.arrays)): + repl_dict[str(memlet)] = f'__in{i}' + in_memlets[f'__in{i}'] = memlet + scope.edge.replace_dict(repl_dict) + tasklet = dnodes.Tasklet('some_label', in_memlets.keys(), {'__out'}, + f"__out = {scope.edge.assignments['__out']}") + tnode = tnodes.TaskletNode(tasklet, in_memlets, out_memlets) + tnode.parent = loop + idx = loop.children.index(scope) + loop.children[idx] = tnode + idx = partition.index(scope) + partition[idx] = tnode + assignments[scope.name] = (scope.value, scope.edge, idx) + elif hasattr(scope, 'children'): frontier.extend(scope.children) _augment_data(data_to_augment, loop, tree) new_scopes = [] - while partition: - child = partition.pop(0) + # while partition: + # child = partition.pop(0) + for i, child in enumerate(partition): if not isinstance(child, list): child = [child] + + for c in list(child): + idx = child.index(c) + # Reverse access? 
+ for name, (value, edge, index) in assignments.items(): + if index == i: + continue + if c.is_data_used(name, True): + child.insert(idx, tnodes.AssignNode(name, copy.deepcopy(value), copy.deepcopy(edge))) + if isinstance(loop, tnodes.MapScope): scope = tnodes.MapScope(sdfg, False, child, copy.deepcopy(loop.node)) else: diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 57538b2747..7a692fcb5f 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -5,8 +5,9 @@ from dace.codegen import control_flow as cf from dace.dtypes import TYPECLASS_TO_STRING from dace.properties import CodeBlock -from dace.sdfg import SDFG +from dace.sdfg import SDFG, InterstateEdge from dace.sdfg.state import SDFGState +from dace.symbolic import symbol from dace.memlet import Memlet from typing import Dict, List, Optional, Set, Tuple, Union @@ -47,7 +48,7 @@ def define_arrays(self, indent: int, defined_arrays: Set[str]) -> Tuple[str, Set # defined_arrays |= undefined_arrays.keys() # return string, defined_arrays - def is_data_used(self, name: str) -> bool: + def is_data_used(self, name: str, include_symbols: bool = False) -> bool: pass # for child in self.children: # if child.is_data_used(name): @@ -61,6 +62,7 @@ class ScheduleTreeScope(ScheduleTreeNode): top_level: bool children: List['ScheduleTreeNode'] containers: Optional[Dict[str, data.Data]] = field(default_factory=dict, init=False) + symbols: Optional[Dict[str, symbol]] = field(default_factory=dict, init=False) # def __init__(self, sdfg: Optional[SDFG] = None, top_level: Optional[bool] = False, children: Optional[List['ScheduleTreeNode']] = None): def __init__(self, sdfg: Optional[SDFG] = None, top_level: Optional[bool] = False, children: Optional[List['ScheduleTreeNode']] = None): @@ -202,6 +204,7 @@ class AssignNode(ScheduleTreeNode): """ name: str value: CodeBlock + edge: InterstateEdge def as_string(self, indent: 
int = 0): return indent * INDENTATION + f'assign {self.name} = {self.value.as_string}' @@ -389,9 +392,14 @@ def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[s code = self.node.code.as_string.replace('\n', '\\n') return string + indent * INDENTATION + f"dace.tree.tasklet(label='{self.node.label}', inputs={{{in_memlets}}}, outputs={{{out_memlets}}}, code='{code}', language=dace.{self.node.language})", defined_arrays - def is_data_used(self, name: str) -> bool: + def is_data_used(self, name: str, include_symbols: bool = False) -> bool: used_data = set([memlet.data for memlet in self.in_memlets.values()]) used_data |= set([memlet.data for memlet in self.out_memlets.values()]) + if include_symbols: + for memlet in self.in_memlets.values(): + used_data |= memlet.subset.free_symbols + if memlet.other_subset: + used_data |= memlet.other_subset.free_symbols return name in used_data From add80bf8d7b4d097e7feabc49bf766dfcfca9c93 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Sat, 21 Jan 2023 16:04:50 +0100 Subject: [PATCH 37/98] Updates for Fission with AssignNodes. 
--- .../analysis/schedule_tree/transformations.py | 97 +++++++++++++------ 1 file changed, 69 insertions(+), 28 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/transformations.py b/dace/sdfg/analysis/schedule_tree/transformations.py index ccc5413118..fd7bfad842 100644 --- a/dace/sdfg/analysis/schedule_tree/transformations.py +++ b/dace/sdfg/analysis/schedule_tree/transformations.py @@ -5,34 +5,13 @@ from dace.sdfg.analysis.schedule_tree import treenodes as tnodes from dace.sdfg.analysis.schedule_tree import utils as tutils import re -from typing import Dict, List, Set, Union +from typing import Dict, List, Set, Tuple, Union _dataflow_nodes = (tnodes.ViewNode, tnodes.RefSetNode, tnodes.CopyNode, tnodes.DynScopeCopyNode, tnodes.TaskletNode, tnodes.LibraryCall) -def _update_memlets(data: Dict[str, dt.Data], memlets: Dict[str, Memlet], index: str, replace: Dict[str, bool]): - for conn, memlet in memlets.items(): - if memlet.data in data: - subset = index if replace[memlet.data] else f"{index}, {memlet.subset}" - memlets[conn] = Memlet(data=memlet.data, subset=subset) - # else: - # repl_dict = dict() - # for s in memlet.subset.free_symbols: - # if s in data: - # repl_dict[s] = f"{s}({index})" - # if repl_dict: - # memlet.subset.replace(repl_dict) - # if memlet.other_subset: - # repl_dict = dict() - # for s in memlet.other_subset.free_symbols: - # if s in data: - # repl_dict[s] = f"{s}({index})" - # if repl_dict: - # memlet.other_subset.replace(repl_dict) - -def _augment_data(data: Set[str], loop: Union[tnodes.MapScope, tnodes.ForScope], tree: tnodes.ScheduleTreeNode): - +def _get_loop_size(loop: Union[tnodes.MapScope, tnodes.ForScope]) -> Tuple[str, list, list]: # Generate loop-related indices, sizes and, strides if isinstance(loop, tnodes.MapScope): map = loop.node.map @@ -74,6 +53,62 @@ def _augment_data(data: Set[str], loop: Union[tnodes.MapScope, tnodes.ForScope], strides = [1] * len(size) for i in range(len(size) - 2, -1, -1): strides[i] = strides[i+1] * 
size[i+1] + + return index, size, strides + + +def _update_memlets(data: Dict[str, dt.Data], memlets: Dict[str, Memlet], index: str, replace: Dict[str, bool]): + for conn, memlet in memlets.items(): + if memlet.data in data: + subset = index if replace[memlet.data] else f"{index}, {memlet.subset}" + memlets[conn] = Memlet(data=memlet.data, subset=subset) + + +def _augment_data(data: Set[str], loop: Union[tnodes.MapScope, tnodes.ForScope], tree: tnodes.ScheduleTreeNode): + + # # Generate loop-related indices, sizes and, strides + # if isinstance(loop, tnodes.MapScope): + # map = loop.node.map + # index = ", ".join(f"{p}/{r[2]}-{r[0]}" if r[2] != 1 else f"{p}-{r[0]}" for p, r in zip(map.params, map.range)) + # size = map.range.size() + # else: + # itervar = loop.header.itervar + # start = loop.header.init + # # NOTE: Condition expression may be inside parentheses + # par = re.search(f"^\s*(\(?)\s*{itervar}", loop.header.condition.as_string).group(1) == '(' + # if par: + # stop_match = re.search(f"\(\s*{itervar}\s*([<>=]+)\s*(.+)\s*\)", loop.header.condition.as_string) + # else: + # stop_match = re.search(f"{itervar}\s*([<>=]+)\s*(.+)", loop.header.condition.as_string) + # stop_op = stop_match.group(1) + # assert stop_op in ("<", "<=", ">", ">=") + # stop = stop_match.group(2) + # # NOTE: Update expression may be inside parentheses + # par = re.search(f"^\s*(\(?)\s*{itervar}", loop.header.update).group(1) == '(' + # if par: + # step_match = re.search(f"\(\s*{itervar}\s*([(*+-/%)]+)\s*([a-zA-Z0-9_]+)\s*\)", loop.header.update) + # else: + # step_match = re.search(f"{itervar}\s*([(*+-/%)]+)\s*([a-zA-Z0-9_]+)", loop.header.update) + # try: + # step_op = step_match.group(1) + # step = step_match.group(2) + # if step_op == '+': + # step = int(step) + # index = f"{itervar}/{step}-{start}" if step != 1 else f"{itervar}-{start}" + # else: + # raise ValueError + # except (AttributeError, ValueError): + # step = 1 if '<' in stop_op else -1 + # index = itervar + # if "=" in 
stop_op: + # stop = f"{stop} + ({step})" + # size = subsets.Range.from_string(f"{start}:{stop}:{step}").size() + + # strides = [1] * len(size) + # for i in range(len(size) - 2, -1, -1): + # strides[i] = strides[i+1] * size[i+1] + + index, size, strides = _get_loop_size(loop) # Augment data descriptors replace = dict() @@ -127,6 +162,8 @@ def loop_fission(loop: Union[tnodes.MapScope, tnodes.ForScope], tree: tnodes.Sch partition = tutils.partition_scope_body(loop) if len(partition) < 2: return [loop] + + index, _, _ = _get_loop_size(loop) data_to_augment = set() assignments = dict() @@ -143,10 +180,10 @@ def loop_fission(loop: Union[tnodes.MapScope, tnodes.ForScope], tree: tnodes.Sch data_to_augment.add(scope.target) elif isinstance(scope, tnodes.AssignNode): symbol = tree.symbols[scope.name] - loop.containers[scope.name] = dt.Scalar(symbol.dtype, transient=True) - data_to_augment.add(scope.name) + loop.containers[f"{scope.name}_arr"] = dt.Scalar(symbol.dtype, transient=True) + data_to_augment.add(f"{scope.name}_arr") repl_dict = {scope.name: '__out'} - out_memlets = {'__out': Memlet(data=scope.name, subset='0')} + out_memlets = {'__out': Memlet(data=f"{scope.name}_arr", subset='0')} in_memlets = dict() for i, memlet in enumerate(scope.edge.get_read_memlets(scope.parent.sdfg.arrays)): repl_dict[str(memlet)] = f'__in{i}' @@ -160,10 +197,14 @@ def loop_fission(loop: Union[tnodes.MapScope, tnodes.ForScope], tree: tnodes.Sch loop.children[idx] = tnode idx = partition.index(scope) partition[idx] = tnode - assignments[scope.name] = (scope.value, scope.edge, idx) + edge = copy.deepcopy(scope.edge) + edge.assignments['__out'] = f"{scope.name}_arr[{index}]" + assignments[scope.name] = (dnodes.CodeBlock(f"{scope.name}[{index}]"), edge, idx) elif hasattr(scope, 'children'): frontier.extend(scope.children) _augment_data(data_to_augment, loop, tree) + print(data_to_augment) + new_scopes = [] # while partition: @@ -179,7 +220,7 @@ def loop_fission(loop: Union[tnodes.MapScope, 
tnodes.ForScope], tree: tnodes.Sch if index == i: continue if c.is_data_used(name, True): - child.insert(idx, tnodes.AssignNode(name, copy.deepcopy(value), copy.deepcopy(edge))) + child.insert(idx, tnodes.AssignNode(f"{name}", copy.deepcopy(value), copy.deepcopy(edge))) if isinstance(loop, tnodes.MapScope): scope = tnodes.MapScope(sdfg, False, child, copy.deepcopy(loop.node)) From 5a8c2677f71d9f6f9c5fc24983eab1b91729c59f Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Sat, 21 Jan 2023 18:17:04 +0100 Subject: [PATCH 38/98] Fix for views becoming program parameters and/or return arrays when they shouldn't. --- dace/frontend/python/newast.py | 40 ++++++++++++------- .../schedule_tree/conversion_test.py | 0 2 files changed, 26 insertions(+), 14 deletions(-) rename tests/{sdfg => }/schedule_tree/conversion_test.py (100%) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 87f91bc3df..38be6f01f2 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -1198,6 +1198,9 @@ def parse_program(self, program: ast.FunctionDef, is_tasklet: bool = False): for vname, arrname in self.variables.items(): if vname.startswith('__return'): if isinstance(self.sdfg.arrays[arrname], data.View): + # If it is a registered view, ignore. 
+ if vname in self.views and self.views[vname][0].startswith('__return'): + continue # In case of a view, make a copy # NOTE: If we are at the top level SDFG (not always clear), # and it is a View of an input array, can we return a NumPy @@ -1233,7 +1236,8 @@ def parse_program(self, program: ast.FunctionDef, is_tasklet: bool = False): for arrname, arr in self.sdfg.arrays.items(): # Return values become non-transient (accessible by the outside) - if arrname.startswith('__return'): + if arrname.startswith('__return') and not ( + vname in self.views and self.views[vname][0].startswith('__return')): arr.transient = False self.outputs[arrname] = (None, Memlet.from_array(arrname, arr), []) @@ -4753,19 +4757,27 @@ def _add_read_slice(self, array: str, node: ast.Subscript, expr: MemletExpr): if not strides: strides = None - if is_index: - tmp = self.sdfg.temp_data_name() - tmp, tmparr = self.sdfg.add_scalar(tmp, arrobj.dtype, arrobj.storage, transient=True) - else: - tmp, tmparr = self.sdfg.add_view(array, - other_subset.size(), - arrobj.dtype, - storage=arrobj.storage, - strides=strides, - find_new_name=True) - self.views[tmp] = (array, - Memlet(f'{array}[{expr.subset}]->{other_subset}', volume=expr.accesses, - wcr=expr.wcr)) + # if is_index: + # tmp = self.sdfg.temp_data_name() + # tmp, tmparr = self.sdfg.add_scalar(tmp, arrobj.dtype, arrobj.storage, transient=True) + # else: + # tmp, tmparr = self.sdfg.add_view(array, + # other_subset.size(), + # arrobj.dtype, + # storage=arrobj.storage, + # strides=strides, + # find_new_name=True) + # self.views[tmp] = (array, + # Memlet(f'{array}[{expr.subset}]->{other_subset}', volume=expr.accesses, + # wcr=expr.wcr)) + tmp, tmparr = self.sdfg.add_view(array, + other_subset.size(), + arrobj.dtype, + storage=arrobj.storage, + strides=strides, + find_new_name=True) + self.views[tmp] = (array, Memlet(f'{array}[{expr.subset}]->{other_subset}', volume=expr.accesses, + wcr=expr.wcr)) self.variables[tmp] = tmp if not isinstance(tmparr, 
data.View): rnode = self.last_state.add_read(array, debuginfo=self.current_lineinfo) diff --git a/tests/sdfg/schedule_tree/conversion_test.py b/tests/schedule_tree/conversion_test.py similarity index 100% rename from tests/sdfg/schedule_tree/conversion_test.py rename to tests/schedule_tree/conversion_test.py From 0cd779a0fde0a9a62ac9c06a1fce9d28c3a1537c Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Sat, 21 Jan 2023 18:17:38 +0100 Subject: [PATCH 39/98] WIP: WCR support in TaskletNodes and when converting back to SDFG. --- dace/frontend/python/replacements.py | 8 ++++++-- dace/sdfg/analysis/schedule_tree/treenodes.py | 16 ++++++++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index fc6333ccbc..ffa05428c3 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -4625,10 +4625,11 @@ def _op(visitor: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, op1: StringLite @oprepo.replaces('dace.tree.tasklet') -def _tasklet(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, label: StringLiteral, inputs: Dict[StringLiteral, str], outputs: Dict[StringLiteral, str], code: StringLiteral, language: dtypes.Language): +def _tasklet(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, label: StringLiteral, inputs: Dict[StringLiteral, str], outputs: Dict[StringLiteral, str], wcr: Dict[StringLiteral, Any], code: StringLiteral, language: dtypes.Language): label = label.value inputs = {k.value: v for k, v in inputs.items()} outputs = {k.value: v for k, v in outputs.items()} + wcr = {k.value: v for k, v in wcr.items()} code = code.value tasklet = state.add_tasklet(label, inputs.keys(), outputs.keys(), code, language) for conn, name in inputs.items(): @@ -4636,7 +4637,10 @@ def _tasklet(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, label: StringLi state.add_edge(access, None, tasklet, conn, Memlet.from_array(name, 
sdfg.arrays[name])) for conn, name in outputs.items(): access = state.add_access(name) - state.add_edge(tasklet, conn, access, None, Memlet.from_array(name, sdfg.arrays[name])) + memlet = Memlet.from_array(name, sdfg.arrays[name]) + if conn in wcr: + memlet.wcr = wcr[conn] + state.add_edge(tasklet, conn, access, None, memlet) # Handle scope output for out in outputs.values(): diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 7a692fcb5f..622f56616e 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -386,11 +386,23 @@ def as_string(self, indent: int = 0): def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: in_memlets = ', '.join(f"'{k}': {v}" for k, v in self.in_memlets.items()) - out_memlets = ', '.join(f"'{k}': {v}" for k, v in self.out_memlets.items()) + # out_memlets = ', '.join(f"'{k}': {v}" for k, v in self.out_memlets.items()) + out_memlets_dict = dict() + wcr = dict() + for k, v in self.out_memlets.items(): + if v.wcr: + w = copy.deepcopy(v) + w.wcr = None + out_memlets_dict[k] = w + wcr[k] = v.wcr + else: + out_memlets_dict[k] = v + out_memlets = ', '.join(f"'{k}': {v}" for k, v in out_memlets_dict.items()) + wcr_memlets = ', '.join(f"'{k}': {v}" for k, v in wcr.items()) defined_arrays = defined_arrays or set() string, defined_arrays = self.define_arrays(indent, defined_arrays) code = self.node.code.as_string.replace('\n', '\\n') - return string + indent * INDENTATION + f"dace.tree.tasklet(label='{self.node.label}', inputs={{{in_memlets}}}, outputs={{{out_memlets}}}, code='{code}', language=dace.{self.node.language})", defined_arrays + return string + indent * INDENTATION + f"dace.tree.tasklet(label='{self.node.label}', inputs={{{in_memlets}}}, outputs={{{out_memlets}}}, wcr={{{wcr_memlets}}}, code='{code}', language=dace.{self.node.language})", defined_arrays def is_data_used(self, name: 
str, include_symbols: bool = False) -> bool: used_data = set([memlet.data for memlet in self.in_memlets.values()]) From da038689d41cefbf5af9b3adfd58dc1034381ffa Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Sat, 21 Jan 2023 18:18:14 +0100 Subject: [PATCH 40/98] Added ScheduleTrree tests: Tasklets --- .../schedule_tree/conversions/tasklet_test.py | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 tests/schedule_tree/conversions/tasklet_test.py diff --git a/tests/schedule_tree/conversions/tasklet_test.py b/tests/schedule_tree/conversions/tasklet_test.py new file mode 100644 index 0000000000..6462dbdd30 --- /dev/null +++ b/tests/schedule_tree/conversions/tasklet_test.py @@ -0,0 +1,104 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import dace +import numpy as np +from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree, as_sdfg + + +def test_simple_tasklet(): + + @dace.program + def simple_tasklet(A: dace.float32[3, 3]): + ret = dace.float32(0) + with dace.tasklet: + c << A[1, 1] + n << A[0, 1] + s << A[2, 1] + w << A[1, 0] + e << A[1, 2] + out = (c + n + s + w + e) / 5 + out >> ret + return ret + + sdfg_pre = simple_tasklet.to_sdfg() + tree = as_schedule_tree(sdfg_pre) + sdfg_post = as_sdfg(tree) + + rng = np.random.default_rng(42) + A = rng.random((3, 3), dtype=np.float32) + ref = (A[1, 1] + A[0, 1] + A[2, 1] + A[1, 0] + A[1, 2]) / 5 + + val_pre = sdfg_pre(A=A)[0] + val_post = sdfg_post(A=A)[0] + + assert np.allclose(val_pre, ref) + assert np.allclose(val_post, ref) + + +def test_multiple_outputs_tasklet(): + + @dace.program + def multiple_outputs_tasklet(A: dace.float32[3, 3]): + ret = np.empty((2,), dtype=np.float32) + with dace.tasklet: + c << A[1, 1] + n << A[0, 1] + s << A[2, 1] + w << A[1, 0] + e << A[1, 2] + out0 = (c + n + s + w + e) / 5 + out1 = c / 2 + (n + s + w + e) / 2 + out0 >> ret[0] + out1 >> ret[1] + return ret + + sdfg_pre = 
multiple_outputs_tasklet.to_sdfg() + tree = as_schedule_tree(sdfg_pre) + sdfg_post = as_sdfg(tree) + + rng = np.random.default_rng(42) + A = rng.random((3, 3), dtype=np.float32) + ref = np.empty((2,), dtype=np.float32) + ref[0] = (A[1, 1] + A[0, 1] + A[2, 1] + A[1, 0] + A[1, 2]) / 5 + ref[1] = A[1, 1] / 2 + (A[0, 1] + A[2, 1] + A[1, 0] + A[1, 2]) / 2 + + val_pre = sdfg_pre(A=A) + val_post = sdfg_post(A=A) + + assert np.allclose(val_pre, ref) + assert np.allclose(val_post, ref) + + +def test_simple_wcr_tasklet(): + + @dace.program + def simple_wcr_tasklet(A: dace.float32[3, 3]): + ret = dace.float32(2) + with dace.tasklet: + c << A[1, 1] + n << A[0, 1] + s << A[2, 1] + w << A[1, 0] + e << A[1, 2] + out = (c + n + s + w + e) / 5 + out >> ret(1, lambda x, y: x + y) + return ret + + sdfg_pre = simple_wcr_tasklet.to_sdfg() + tree = as_schedule_tree(sdfg_pre) + sdfg_post = as_sdfg(tree) + + rng = np.random.default_rng(42) + A = rng.random((3, 3), dtype=np.float32) + ref = 2 + (A[1, 1] + A[0, 1] + A[2, 1] + A[1, 0] + A[1, 2]) / 5 + + val_pre = sdfg_pre(A=A)[0] + val_post = sdfg_post(A=A)[0] + + assert np.allclose(val_pre, ref) + assert np.allclose(val_post, ref) + + +if __name__ == "__main__": + # test_simple_tasklet() + # test_multiple_outputs_tasklet() + test_simple_wcr_tasklet() From 33df6a1ba385c50c844f8f07cb40a9bb5452a5df Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Sun, 22 Jan 2023 13:20:11 +0100 Subject: [PATCH 41/98] Always add an access for View/Slices. 
--- dace/frontend/python/newast.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 38be6f01f2..7dc3aeb6ea 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -4779,11 +4779,12 @@ def _add_read_slice(self, array: str, node: ast.Subscript, expr: MemletExpr): self.views[tmp] = (array, Memlet(f'{array}[{expr.subset}]->{other_subset}', volume=expr.accesses, wcr=expr.wcr)) self.variables[tmp] = tmp - if not isinstance(tmparr, data.View): - rnode = self.last_state.add_read(array, debuginfo=self.current_lineinfo) - wnode = self.last_state.add_write(tmp, debuginfo=self.current_lineinfo) - self.last_state.add_nedge( - rnode, wnode, Memlet(f'{array}[{expr.subset}]->{other_subset}', volume=expr.accesses, wcr=expr.wcr)) + self.last_state.add_access(tmp, debuginfo=self.current_lineinfo) + # if not isinstance(tmparr, data.View): + # rnode = self.last_state.add_read(array, debuginfo=self.current_lineinfo) + # wnode = self.last_state.add_write(tmp, debuginfo=self.current_lineinfo) + # self.last_state.add_nedge( + # rnode, wnode, Memlet(f'{array}[{expr.subset}]->{other_subset}', volume=expr.accesses, wcr=expr.wcr)) return tmp def _parse_subscript_slice(self, From c8c3bffceb2f5453a0f5ca08ae1db6ff84fae5a8 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Sun, 22 Jan 2023 13:38:55 +0100 Subject: [PATCH 42/98] Added support for WCR in both input and output Memlets. 
--- dace/frontend/python/replacements.py | 25 ++++++++--- dace/sdfg/analysis/schedule_tree/treenodes.py | 44 ++++++++++++------- 2 files changed, 49 insertions(+), 20 deletions(-) diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index ffa05428c3..f8725c54d5 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -4625,21 +4625,36 @@ def _op(visitor: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, op1: StringLite @oprepo.replaces('dace.tree.tasklet') -def _tasklet(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, label: StringLiteral, inputs: Dict[StringLiteral, str], outputs: Dict[StringLiteral, str], wcr: Dict[StringLiteral, Any], code: StringLiteral, language: dtypes.Language): +def _tasklet(pv: 'ProgramVisitor', + sdfg: SDFG, + state: SDFGState, + label: StringLiteral, + inputs: Dict[StringLiteral, str], + inputs_wcr: Dict[StringLiteral, Union[None, Callable[[Any, Any], Any]]], + outputs: Dict[StringLiteral, str], + outputs_wcr: Dict[StringLiteral, Union[None, Callable[[Any, Any], Any]]], + code: StringLiteral, + language: dtypes.Language): + + # Extract strings from StringLiterals label = label.value inputs = {k.value: v for k, v in inputs.items()} + inputs_wcr = {k.value: v for k, v in inputs_wcr.items()} outputs = {k.value: v for k, v in outputs.items()} - wcr = {k.value: v for k, v in wcr.items()} + outputs_wcr = {k.value: v for k, v in outputs_wcr.items()} code = code.value + + # Create Tasklet tasklet = state.add_tasklet(label, inputs.keys(), outputs.keys(), code, language) for conn, name in inputs.items(): access = state.add_access(name) - state.add_edge(access, None, tasklet, conn, Memlet.from_array(name, sdfg.arrays[name])) + memlet = Memlet.from_array(name, sdfg.arrays[name]) + memlet.wcr = inputs_wcr[conn] + state.add_edge(access, None, tasklet, conn, memlet) for conn, name in outputs.items(): access = state.add_access(name) memlet = Memlet.from_array(name, 
sdfg.arrays[name]) - if conn in wcr: - memlet.wcr = wcr[conn] + memlet.wcr = outputs_wcr[conn] state.add_edge(tasklet, conn, access, None, memlet) # Handle scope output diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 622f56616e..8ebb78af21 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -373,6 +373,26 @@ def as_string(self, indent: int = 0): return result + super().as_string(indent) +def _memlets_as_string(d: Dict[str, Memlet]) -> Tuple[str, str]: + + memlets_dict = dict() + wcr_dict = dict() + + for k, v in d.items(): + + assert v.other_subset == None + w = copy.deepcopy(v) + w.wcr = None + + memlets_dict[k] = w + wcr_dict[k] = v.wcr + + memlets_str = ', '.join(f"'{k}': {v}" for k, v in memlets_dict.items()) + wcr_str = ', '.join(f"'{k}': {v}" for k, v in wcr_dict.items()) + + return memlets_str, wcr_str + + @dataclass class TaskletNode(ScheduleTreeNode): node: nodes.Tasklet @@ -385,24 +405,18 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{out_memlets} = tasklet({in_memlets})' def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: - in_memlets = ', '.join(f"'{k}': {v}" for k, v in self.in_memlets.items()) - # out_memlets = ', '.join(f"'{k}': {v}" for k, v in self.out_memlets.items()) - out_memlets_dict = dict() - wcr = dict() - for k, v in self.out_memlets.items(): - if v.wcr: - w = copy.deepcopy(v) - w.wcr = None - out_memlets_dict[k] = w - wcr[k] = v.wcr - else: - out_memlets_dict[k] = v - out_memlets = ', '.join(f"'{k}': {v}" for k, v in out_memlets_dict.items()) - wcr_memlets = ', '.join(f"'{k}': {v}" for k, v in wcr.items()) + in_memlets, in_wcr = _memlets_as_string(self.in_memlets) + out_memlets, out_wcr = _memlets_as_string(self.out_memlets) defined_arrays = defined_arrays or set() string, defined_arrays = self.define_arrays(indent, defined_arrays) code = 
self.node.code.as_string.replace('\n', '\\n') - return string + indent * INDENTATION + f"dace.tree.tasklet(label='{self.node.label}', inputs={{{in_memlets}}}, outputs={{{out_memlets}}}, wcr={{{wcr_memlets}}}, code='{code}', language=dace.{self.node.language})", defined_arrays + return ( + string + indent * INDENTATION + (f"dace.tree.tasklet(label='{self.node.label}', inputs={{{in_memlets}}}, " + f"inputs_wcr={{{in_wcr}}}, outputs={{{out_memlets}}}, " + f"outputs_wcr={{{out_wcr}}}, code='{code}', " + f"language=dace.{self.node.language})"), + defined_arrays + ) def is_data_used(self, name: str, include_symbols: bool = False) -> bool: used_data = set([memlet.data for memlet in self.in_memlets.values()]) From db2d8aa3735f7a6d0baa2a6c9ee90d2729b2cd36 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Sun, 22 Jan 2023 13:39:22 +0100 Subject: [PATCH 43/98] Added more tests. --- .../schedule_tree/conversions/tasklet_test.py | 36 +++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/tests/schedule_tree/conversions/tasklet_test.py b/tests/schedule_tree/conversions/tasklet_test.py index 6462dbdd30..b383dfb46d 100644 --- a/tests/schedule_tree/conversions/tasklet_test.py +++ b/tests/schedule_tree/conversions/tasklet_test.py @@ -98,7 +98,39 @@ def simple_wcr_tasklet(A: dace.float32[3, 3]): assert np.allclose(val_post, ref) +def test_simple_wcr_tasklet2(): + + @dace.program + def simple_wcr_tasklet2(A: dace.float32[3, 3]): + ret = dace.float32(2) + with dace.tasklet: + c << A[1, 1] + n << A[0, 1] + s << A[2, 1] + w << A[1, 0] + e << A[1, 2] + inp << ret(1, lambda x, y: x + y) + out = (c + n + s + w + e) / 5 + out >> ret(1, lambda x, y: x + y) + return ret + + sdfg_pre = simple_wcr_tasklet2.to_sdfg() + tree = as_schedule_tree(sdfg_pre) + sdfg_post = as_sdfg(tree) + + rng = np.random.default_rng(42) + A = rng.random((3, 3), dtype=np.float32) + ref = 2 + (A[1, 1] + A[0, 1] + A[2, 1] + A[1, 0] + A[1, 2]) / 5 + + val_pre = sdfg_pre(A=A)[0] + 
val_post = sdfg_post(A=A)[0] + + assert np.allclose(val_pre, ref) + assert np.allclose(val_post, ref) + + if __name__ == "__main__": - # test_simple_tasklet() - # test_multiple_outputs_tasklet() + test_simple_tasklet() + test_multiple_outputs_tasklet() test_simple_wcr_tasklet() + test_simple_wcr_tasklet2() From 8ffcf8f82bc946f4ab53e7b426722acd519f1801 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Sun, 22 Jan 2023 14:27:28 +0100 Subject: [PATCH 44/98] Allow Views to have only incoming or outgoing edges. --- dace/sdfg/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 03b69adfd6..c33578f396 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -810,7 +810,7 @@ def get_view_edge(state: SDFGState, view: nd.AccessNode) -> gr.MultiConnectorEdg out_edges = state.out_edges(view) # Invalid case: No data to view - if len(in_edges) == 0 or len(out_edges) == 0: + if len(in_edges) == 0 and len(out_edges) == 0: return None # If there is one edge (in/out) that leads (via memlet path) to an access From b3bd6b2c7d54564c2df15fac76427d3ee1976881 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Sun, 22 Jan 2023 15:25:27 +0100 Subject: [PATCH 45/98] Reverted all changes regarding Views. --- dace/frontend/python/newast.py | 51 +++++++++++++--------------------- 1 file changed, 19 insertions(+), 32 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 7dc3aeb6ea..87f91bc3df 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -1198,9 +1198,6 @@ def parse_program(self, program: ast.FunctionDef, is_tasklet: bool = False): for vname, arrname in self.variables.items(): if vname.startswith('__return'): if isinstance(self.sdfg.arrays[arrname], data.View): - # If it is a registered view, ignore. 
- if vname in self.views and self.views[vname][0].startswith('__return'): - continue # In case of a view, make a copy # NOTE: If we are at the top level SDFG (not always clear), # and it is a View of an input array, can we return a NumPy @@ -1236,8 +1233,7 @@ def parse_program(self, program: ast.FunctionDef, is_tasklet: bool = False): for arrname, arr in self.sdfg.arrays.items(): # Return values become non-transient (accessible by the outside) - if arrname.startswith('__return') and not ( - vname in self.views and self.views[vname][0].startswith('__return')): + if arrname.startswith('__return'): arr.transient = False self.outputs[arrname] = (None, Memlet.from_array(arrname, arr), []) @@ -4757,34 +4753,25 @@ def _add_read_slice(self, array: str, node: ast.Subscript, expr: MemletExpr): if not strides: strides = None - # if is_index: - # tmp = self.sdfg.temp_data_name() - # tmp, tmparr = self.sdfg.add_scalar(tmp, arrobj.dtype, arrobj.storage, transient=True) - # else: - # tmp, tmparr = self.sdfg.add_view(array, - # other_subset.size(), - # arrobj.dtype, - # storage=arrobj.storage, - # strides=strides, - # find_new_name=True) - # self.views[tmp] = (array, - # Memlet(f'{array}[{expr.subset}]->{other_subset}', volume=expr.accesses, - # wcr=expr.wcr)) - tmp, tmparr = self.sdfg.add_view(array, - other_subset.size(), - arrobj.dtype, - storage=arrobj.storage, - strides=strides, - find_new_name=True) - self.views[tmp] = (array, Memlet(f'{array}[{expr.subset}]->{other_subset}', volume=expr.accesses, - wcr=expr.wcr)) + if is_index: + tmp = self.sdfg.temp_data_name() + tmp, tmparr = self.sdfg.add_scalar(tmp, arrobj.dtype, arrobj.storage, transient=True) + else: + tmp, tmparr = self.sdfg.add_view(array, + other_subset.size(), + arrobj.dtype, + storage=arrobj.storage, + strides=strides, + find_new_name=True) + self.views[tmp] = (array, + Memlet(f'{array}[{expr.subset}]->{other_subset}', volume=expr.accesses, + wcr=expr.wcr)) self.variables[tmp] = tmp - 
self.last_state.add_access(tmp, debuginfo=self.current_lineinfo) - # if not isinstance(tmparr, data.View): - # rnode = self.last_state.add_read(array, debuginfo=self.current_lineinfo) - # wnode = self.last_state.add_write(tmp, debuginfo=self.current_lineinfo) - # self.last_state.add_nedge( - # rnode, wnode, Memlet(f'{array}[{expr.subset}]->{other_subset}', volume=expr.accesses, wcr=expr.wcr)) + if not isinstance(tmparr, data.View): + rnode = self.last_state.add_read(array, debuginfo=self.current_lineinfo) + wnode = self.last_state.add_write(tmp, debuginfo=self.current_lineinfo) + self.last_state.add_nedge( + rnode, wnode, Memlet(f'{array}[{expr.subset}]->{other_subset}', volume=expr.accesses, wcr=expr.wcr)) return tmp def _parse_subscript_slice(self, From 1b2a5e6857551762ffdb2b2bb6188547b9fce5a7 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Sun, 22 Jan 2023 15:26:17 +0100 Subject: [PATCH 46/98] TaskletNode.as_python now uses explicit dataflow syntax. --- dace/sdfg/analysis/schedule_tree/treenodes.py | 43 +++++++------------ 1 file changed, 15 insertions(+), 28 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 8ebb78af21..33c9085b46 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -9,6 +9,7 @@ from dace.sdfg.state import SDFGState from dace.symbolic import symbol from dace.memlet import Memlet +import math from typing import Dict, List, Optional, Set, Tuple, Union INDENTATION = ' ' @@ -373,24 +374,12 @@ def as_string(self, indent: int = 0): return result + super().as_string(indent) -def _memlets_as_string(d: Dict[str, Memlet]) -> Tuple[str, str]: - - memlets_dict = dict() - wcr_dict = dict() - - for k, v in d.items(): - - assert v.other_subset == None - w = copy.deepcopy(v) - w.wcr = None - - memlets_dict[k] = w - wcr_dict[k] = v.wcr - - memlets_str = ', '.join(f"'{k}': {v}" for k, v in memlets_dict.items()) - 
wcr_str = ', '.join(f"'{k}': {v}" for k, v in wcr_dict.items()) - - return memlets_str, wcr_str +def _memlet_to_str(memlet: Memlet) -> str: + assert memlet.other_subset == None + wcr = "" + if memlet.wcr: + wcr = f"({math.prod(memlet.subset.size())}, {memlet.wcr})" + return f"{memlet.data}{wcr}[{memlet.subset}]" @dataclass @@ -405,18 +394,16 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{out_memlets} = tasklet({in_memlets})' def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: - in_memlets, in_wcr = _memlets_as_string(self.in_memlets) - out_memlets, out_wcr = _memlets_as_string(self.out_memlets) + explicit_dataflow = indent * INDENTATION + "with dace.tasklet:\n" + for conn, memlet in self.in_memlets.items(): + explicit_dataflow += (indent + 1) * INDENTATION + f"{conn} << {_memlet_to_str(memlet)}\n" + for conn, memlet in self.out_memlets.items(): + explicit_dataflow += (indent + 1) * INDENTATION + f"{conn} >> {_memlet_to_str(memlet)}\n" + code = self.node.code.as_string.replace('\n', f"\n{(indent + 1) * INDENTATION}") + explicit_dataflow += (indent + 1) * INDENTATION + code defined_arrays = defined_arrays or set() string, defined_arrays = self.define_arrays(indent, defined_arrays) - code = self.node.code.as_string.replace('\n', '\\n') - return ( - string + indent * INDENTATION + (f"dace.tree.tasklet(label='{self.node.label}', inputs={{{in_memlets}}}, " - f"inputs_wcr={{{in_wcr}}}, outputs={{{out_memlets}}}, " - f"outputs_wcr={{{out_wcr}}}, code='{code}', " - f"language=dace.{self.node.language})"), - defined_arrays - ) + return string + explicit_dataflow, defined_arrays def is_data_used(self, name: str, include_symbols: bool = False) -> bool: used_data = set([memlet.data for memlet in self.in_memlets.values()]) From bbdbdcede25938185f3a8bbd60bef36809aa38c1 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Sun, 22 Jan 2023 15:33:11 +0100 Subject: [PATCH 47/98] Added documentation --- 
tests/schedule_tree/conversions/tasklet_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/schedule_tree/conversions/tasklet_test.py b/tests/schedule_tree/conversions/tasklet_test.py index b383dfb46d..bc527fa05c 100644 --- a/tests/schedule_tree/conversions/tasklet_test.py +++ b/tests/schedule_tree/conversions/tasklet_test.py @@ -1,10 +1,12 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Tests conversion of Tasklets from SDFG to ScheduleTree and back. """ import dace import numpy as np from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree, as_sdfg def test_simple_tasklet(): + """ Tests a Tasklet with a single (non-WCR) output. """ @dace.program def simple_tasklet(A: dace.float32[3, 3]): @@ -35,6 +37,7 @@ def simple_tasklet(A: dace.float32[3, 3]): def test_multiple_outputs_tasklet(): + """ Tests a Tasklet with multiple (non-WCR) outputs. """ @dace.program def multiple_outputs_tasklet(A: dace.float32[3, 3]): @@ -69,6 +72,7 @@ def multiple_outputs_tasklet(A: dace.float32[3, 3]): def test_simple_wcr_tasklet(): + """ Tests a Tasklet with a single WCR output. """ @dace.program def simple_wcr_tasklet(A: dace.float32[3, 3]): @@ -99,6 +103,7 @@ def simple_wcr_tasklet(A: dace.float32[3, 3]): def test_simple_wcr_tasklet2(): + """ Tests a tasklet with a single WCR output. The output is also (fake) input with WCR. """ @dace.program def simple_wcr_tasklet2(A: dace.float32[3, 3]): From 9f1fdee36697e7c6431daa6cb548d515a24c5ed9 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Sun, 22 Jan 2023 15:59:49 +0100 Subject: [PATCH 48/98] De-alias the SDFG just before populating the containers. 
--- dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 0ed1ee779b..c6bf8c25e9 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -416,7 +416,7 @@ def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) if toplevel: # Top-level SDFG preparation (only perform once) # Handle name collisions (in arrays, state labels, symbols) - dealias_sdfg(sdfg) remove_name_collisions(sdfg) ############################# @@ -502,6 +501,7 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche result = tn.ScheduleTreeScope(sdfg=sdfg, top_level=True, children=totree(cfg)) if toplevel: + dealias_sdfg(sdfg) populate_containers(result) # Clean up tree From 3cfb7082be359808641b947655cfdc9e51ecf319 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Sun, 22 Jan 2023 18:12:25 +0100 Subject: [PATCH 49/98] Fixes to SDFGs dealiasing (multiple NestedSDFG connectors from/to the same parent data).
--- .../analysis/schedule_tree/sdfg_to_tree.py | 53 ++++++++++++++++++- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index c6bf8c25e9..8b4020c933 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -2,7 +2,7 @@ import copy from typing import Dict, List, Set import dace -from dace import symbolic, data +from dace import data, subsets, symbolic from dace.codegen import control_flow as cf from dace.sdfg.sdfg import InterstateEdge, SDFG from dace.sdfg.state import SDFGState @@ -24,6 +24,9 @@ def dealias_sdfg(sdfg: SDFG): continue replacements: Dict[str, str] = {} + inv_replacements: Dict[str, List[str]] = {} + parent_edges: Dict[str, Memlet] = {} + to_unsqueeze: Set[str] = set() parent_sdfg = nsdfg.parent_sdfg parent_state = nsdfg.parent @@ -37,8 +40,54 @@ def dealias_sdfg(sdfg: SDFG): assert parent_name in parent_sdfg.arrays if name != parent_name: replacements[name] = parent_name + parent_edges[name] = edge + if parent_name in inv_replacements: + inv_replacements[parent_name].append(name) + to_unsqueeze.add(parent_name) + else: + inv_replacements[parent_name] = [name] break + if to_unsqueeze: + for parent_name in to_unsqueeze: + parent_arr = parent_sdfg.arrays[parent_name] + if isinstance(parent_arr, data.View): + parent_arr = data.Array(parent_arr.dtype, parent_arr.shape, parent_arr.transient, + parent_arr.allow_conflicts, parent_arr.storage, parent_arr.location, + parent_arr.strides, parent_arr.offset, parent_arr.may_alias, + parent_arr.lifetime, parent_arr.alignment, parent_arr.debuginfo, + parent_arr.total_size, parent_arr.start_offset, parent_arr.optional, + parent_arr.pool) + child_names = inv_replacements[parent_name] + for name in child_names: + child_arr = copy.deepcopy(parent_arr) + child_arr.transient = False + nsdfg.arrays[name] = child_arr + for state in nsdfg.states(): + 
for e in state.edges(): + if e.data.data in child_names: + e.data = unsqueeze_memlet(e.data, parent_edges[e.data.data].data) + for e in nsdfg.edges(): + repl_dict = dict() + syms = e.data.read_symbols() + for memlet in e.data.get_read_memlets(nsdfg.arrays): + if memlet.data in child_names: + repl_dict[str(memlet)] = unsqueeze_memlet(memlet, parent_edges[memlet.data].data) + if memlet.data in syms: + syms.remove(memlet.data) + for s in syms: + if s in parent_edges: + repl_dict[s] = str(parent_edges[s].data) + e.data.replace_dict(repl_dict) + for name in child_names: + edge = parent_edges[name] + for e in parent_state.memlet_tree(edge): + if e.data.data == parent_name: + e.data.subset = subsets.Range.from_array(parent_arr) + else: + e.data.other_subset = subsets.Range.from_array(parent_arr) + + if replacements: nsdfg.replace_dict(replacements) parent_node.in_connectors = {replacements[c] if c in replacements else c: t for c, t in parent_node.in_connectors.items()} @@ -415,6 +464,7 @@ def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) xfh.replace_code_to_code_edges(sdfg) if toplevel: # Top-level SDFG preparation (only perform once) + dealias_sdfg(sdfg) # Handle name collisions (in arrays, state labels, symbols) remove_name_collisions(sdfg) @@ -501,7 +551,6 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche result = tn.ScheduleTreeScope(sdfg=sdfg, top_level=True, children=totree(cfg)) if toplevel: - dealias_sdfg(sdfg) populate_containers(result) # Clean up tree From d68a016aabf93f5d2cf2e875e273fb7a838f6322 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Sun, 22 Jan 2023 18:12:58 +0100 Subject: [PATCH 50/98] WIP: `as_python` support for ViewNodes. 
--- dace/sdfg/analysis/schedule_tree/treenodes.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 33c9085b46..daa2b4fd56 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -9,7 +9,7 @@ from dace.sdfg.state import SDFGState from dace.symbolic import symbol from dace.memlet import Memlet -import math +from functools import reduce from typing import Dict, List, Optional, Set, Tuple, Union INDENTATION = ' ' @@ -378,7 +378,7 @@ def _memlet_to_str(memlet: Memlet) -> str: assert memlet.other_subset == None wcr = "" if memlet.wcr: - wcr = f"({math.prod(memlet.subset.size())}, {memlet.wcr})" + wcr = f"({reduce(lambda x, y: x * y, memlet.subset.size())}, {memlet.wcr})" return f"{memlet.data}{wcr}[{memlet.subset}]" @@ -523,6 +523,15 @@ class ViewNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = view {self.memlet} as {self.view_desc.shape}' + + def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: + defined_arrays = defined_arrays or set() + string, defined_arrays = self.define_arrays(indent, defined_arrays) + return string + indent * INDENTATION + f"{self.target} = {self.memlet}", defined_arrays + + def is_data_used(self, name: str) -> bool: + # NOTE: View data must not be considered used + return name is self.memlet.data @dataclass From f930a949ae30c848189d5f38a5fc17309c26778b Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Sun, 22 Jan 2023 18:13:24 +0100 Subject: [PATCH 51/98] Added ScheduleTree conversion tests for Map scopes. 
--- tests/schedule_tree/conversions/map_test.py | 245 ++++++++++++++++++++ 1 file changed, 245 insertions(+) create mode 100644 tests/schedule_tree/conversions/map_test.py diff --git a/tests/schedule_tree/conversions/map_test.py b/tests/schedule_tree/conversions/map_test.py new file mode 100644 index 0000000000..540aabfbbc --- /dev/null +++ b/tests/schedule_tree/conversions/map_test.py @@ -0,0 +1,245 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Tests conversion of Map scopes from SDFG to ScheduleTree and back. """ +import dace +import numpy as np +from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree, as_sdfg + + +def test_simple_map(): + """ Tests a Map Scope with a single (non-WCR) output. """ + + M, N = (dace.symbol(s) for s in ('M', 'N')) + + @dace.program + def simple_map(A: dace.float32[M, N]): + B = np.zeros((M, N), dtype=A.dtype) + for i, j in dace.map[1:M-1, 1:N-1]: + with dace.tasklet: + c << A[i, j] + n << A[i-1, j] + s << A[i+1, j] + w << A[i, j-1] + e << A[i, j+1] + out = (c + n + s + w + e) / 5 + out >> B[i, j] + return B + + sdfg_pre = simple_map.to_sdfg() + tree = as_schedule_tree(sdfg_pre) + print(tree.as_string()) + sdfg_post = as_sdfg(tree) + + rng = np.random.default_rng(42) + A = rng.random((20, 20), dtype=np.float32) + ref = np.zeros_like(A) + for i, j in dace.map[1:19, 1:19]: + ref[i, j] = (A[i, j] + A[i-1, j] + A[i+1, j] + A[i, j-1] + A[i, j+1]) / 5 + + val_pre = sdfg_pre(A=A, M=20, N=20) + val_post = sdfg_post(A=A, M=20, N=20) + + assert np.allclose(val_pre, ref) + assert np.allclose(val_post, ref) + + +def test_multiple_outputs_map(): + """ Tests a Map Scope with multiple (non-WCR) outputs. 
""" + + M, N = (dace.symbol(s) for s in ('M', 'N')) + + @dace.program + def multiple_outputs_map(A: dace.float32[M, N]): + B = np.zeros((2, M, N), dtype=A.dtype) + for i, j in dace.map[1:M-1, 1:N-1]: + with dace.tasklet: + c << A[i, j] + n << A[i-1, j] + s << A[i+1, j] + w << A[i, j-1] + e << A[i, j+1] + out0 = (c + n + s + w + e) / 5 + out1 = c / 2 + (n + s + w + e) / 2 + out0 >> B[0, i, j] + out1 >> B[1, i, j] + return B + + sdfg_pre = multiple_outputs_map.to_sdfg() + tree = as_schedule_tree(sdfg_pre) + sdfg_post = as_sdfg(tree) + + rng = np.random.default_rng(42) + A = rng.random((20, 20), dtype=np.float32) + ref = np.zeros_like(A, shape=(2, 20, 20)) + for i, j in dace.map[1:19, 1:19]: + ref[0, i, j] = (A[i, j] + A[i-1, j] + A[i+1, j] + A[i, j-1] + A[i, j+1]) / 5 + ref[1, i, j] = A[i, j] / 2 + (A[i-1, j] + A[i+1, j] + A[i, j-1] + A[i, j+1]) / 2 + + val_pre = sdfg_pre(A=A, M=20, N=20) + val_post = sdfg_post(A=A, M=20, N=20) + + assert np.allclose(val_pre, ref) + assert np.allclose(val_post, ref) + + +def test_simple_wcr_map(): + """ Tests a Map Scope with a single WCR output. 
""" + + M, N = (dace.symbol(s) for s in ('M', 'N')) + + @dace.program + def simple_wcr_map(A: dace.float32[M, N]): + ret = dace.float32(0) + for i, j in dace.map[1:M-1, 1:N-1]: + with dace.tasklet: + c << A[i, j] + n << A[i-1, j] + s << A[i+1, j] + w << A[i, j-1] + e << A[i, j+1] + out = (c + n + s + w + e) / 5 + out >> ret(1, lambda x, y: x + y) + return ret + + sdfg_pre = simple_wcr_map.to_sdfg() + tree = as_schedule_tree(sdfg_pre) + sdfg_post = as_sdfg(tree) + + rng = np.random.default_rng(42) + A = rng.random((20, 20), dtype=np.float32) + ref = np.float32(0) + for i, j in dace.map[1:19, 1:19]: + ref += (A[i, j] + A[i-1, j] + A[i+1, j] + A[i, j-1] + A[i, j+1]) / 5 + + val_pre = sdfg_pre(A=A, M=20, N=20)[0] + val_post = sdfg_post(A=A, M=20, N=20)[0] + + assert np.allclose(val_pre, ref) + assert np.allclose(val_post, ref) + + +def test_simple_wcr_map2(): + """ Tests a Map Scope with a single WCR output. The output is also (fake) input with WCR. """ + + M, N = (dace.symbol(s) for s in ('M', 'N')) + + @dace.program + def simple_wcr_map2(A: dace.float32[M, N]): + ret = dace.float32(0) + for i, j in dace.map[1:M-1, 1:N-1]: + with dace.tasklet: + c << A[i, j] + n << A[i-1, j] + s << A[i+1, j] + w << A[i, j-1] + e << A[i, j+1] + inp << ret(1, lambda x, y: x + y) + out = (c + n + s + w + e) / 5 + out >> ret(1, lambda x, y: x + y) + return ret + + sdfg_pre = simple_wcr_map2.to_sdfg() + tree = as_schedule_tree(sdfg_pre) + sdfg_post = as_sdfg(tree) + + rng = np.random.default_rng(42) + A = rng.random((20, 20), dtype=np.float32) + ref = np.float32(0) + for i, j in dace.map[1:19, 1:19]: + ref += (A[i, j] + A[i-1, j] + A[i+1, j] + A[i, j-1] + A[i, j+1]) / 5 + + val_pre = sdfg_pre(A=A, M=20, N=20)[0] + val_post = sdfg_post(A=A, M=20, N=20)[0] + + assert np.allclose(val_pre, ref) + assert np.allclose(val_post, ref) + + +def test_multiple_outputs_mixed_map(): + """ Tests a Map Scope with multiple (WCR and non-WCR) outputs. 
""" + + M, N = (dace.symbol(s) for s in ('M', 'N')) + + @dace.program + def multiple_outputs_map(A: dace.float32[M, N]): + B = np.zeros((M, N), dtype=A.dtype) + ret = np.float32(0) + for i, j in dace.map[1:M-1, 1:N-1]: + with dace.tasklet: + c << A[i, j] + n << A[i-1, j] + s << A[i+1, j] + w << A[i, j-1] + e << A[i, j+1] + out0 = (c + n + s + w + e) / 5 + out1 = out0 + out0 >> B[i, j] + out1 >> ret(1, lambda x, y: x + y) + return B, ret + + sdfg_pre = multiple_outputs_map.to_sdfg() + tree = as_schedule_tree(sdfg_pre) + sdfg_post = as_sdfg(tree) + + rng = np.random.default_rng(42) + A = rng.random((20, 20), dtype=np.float32) + ref0 = np.zeros_like(A, shape=(20, 20)) + ref1 = np.float32(0) + for i, j in dace.map[1:19, 1:19]: + ref0[i, j] = (A[i, j] + A[i-1, j] + A[i+1, j] + A[i, j-1] + A[i, j+1]) / 5 + ref1 += ref0[i, j] + + val_pre = sdfg_pre(A=A, M=20, N=20) + val_post = sdfg_post(A=A, M=20, N=20) + + assert np.allclose(val_pre[0], ref0) + assert np.allclose(val_pre[1], ref1) + assert np.allclose(val_post[0], ref0) + assert np.allclose(val_post[1], ref1) + + +# NOTE: This fails due to input connector appearing to be written (issue with Views) +def test_nested_simple_map(): + """ Tests a nested Map Scope with a single (non-WCR) output. 
""" + + M, N = (dace.symbol(s) for s in ('M', 'N')) + + @dace.program + def nested_simple_map(A: dace.float32[M, N]): + B = np.zeros((M, N), dtype=A.dtype) + for i, j in dace.map[1:M-2:2, 1:N-2:2]: + inA = A[i-1:i+3, j-1:j+3] + for k, l in dace.map[0:2, 0:2]: + with dace.tasklet: + c << inA[k+1, l+1] + n << inA[k, l+1] + s << inA[k+2, l+1] + w << inA[k+1, l] + e << inA[k+1, l+2] + out = (c + n + s + w + e) / 5 + out >> B[i+k, j+l] + return B + + sdfg_pre = nested_simple_map.to_sdfg() + tree = as_schedule_tree(sdfg_pre) + sdfg_post = as_sdfg(tree) + + rng = np.random.default_rng(42) + A = rng.random((20, 20), dtype=np.float32) + ref = np.zeros_like(A) + for i, j in dace.map[1:19, 1:19]: + ref[i, j] = (A[i, j] + A[i-1, j] + A[i+1, j] + A[i, j-1] + A[i, j+1]) / 5 + + val_pre = sdfg_pre(A=A, M=20, N=20) + val_post = sdfg_post(A=A, M=20, N=20) + + assert np.allclose(val_pre, ref) + assert np.allclose(val_post, ref) + + +if __name__ == "__main__": + test_simple_map() + test_multiple_outputs_map() + test_simple_wcr_map() + test_simple_wcr_map2() + test_multiple_outputs_mixed_map() + test_nested_simple_map() From 57d42f7ccffdbb646db3ca964b57038c26a2504a Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Mon, 23 Jan 2023 12:12:21 +0100 Subject: [PATCH 52/98] Ignore reassignment of Views to the same Array slice. 
--- dace/frontend/python/newast.py | 6 ++++++ tests/schedule_tree/conversions/map_test.py | 1 - 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 87f91bc3df..f371802442 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -3116,6 +3116,12 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): if (not is_return and isinstance(target, ast.Name) and true_name and not op and not isinstance(true_array, data.Scalar) and not (true_array.shape == (1, ))): + if true_name in self.views: + if result in self.sdfg.arrays and self.views[true_name] == ( + result, Memlet.from_array(result, self.sdfg.arrays[result])): + continue + else: + raise DaceSyntaxError(self, target, 'Cannot reassign View "{}"'.format(name)) if (isinstance(result, str) and result in self.sdfg.arrays and self.sdfg.arrays[result].is_equivalent(true_array)): # Skip error if the arrays are defined exactly in the same way. diff --git a/tests/schedule_tree/conversions/map_test.py b/tests/schedule_tree/conversions/map_test.py index 540aabfbbc..8761d82c3c 100644 --- a/tests/schedule_tree/conversions/map_test.py +++ b/tests/schedule_tree/conversions/map_test.py @@ -197,7 +197,6 @@ def multiple_outputs_map(A: dace.float32[M, N]): assert np.allclose(val_post[1], ref1) -# NOTE: This fails due to input connector appearing to be written (issue with Views) def test_nested_simple_map(): """ Tests a nested Map Scope with a single (non-WCR) output. """ From 731a0bf55f9e2047f790d5f3ecbff06c3df64745 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Mon, 23 Jan 2023 14:04:21 +0100 Subject: [PATCH 53/98] Support codes with a single statement.
--- dace/properties.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/dace/properties.py b/dace/properties.py index 0e8a010d71..b06bd71959 100644 --- a/dace/properties.py +++ b/dace/properties.py @@ -987,8 +987,11 @@ def get_free_symbols(self, defined_syms: Set[str] = None) -> Set[str]: if self.language == dace.dtypes.Language.Python: visitor = TaskletFreeSymbolVisitor(defined_syms) if self.code: - for stmt in self.code: - visitor.visit(stmt) + if isinstance(self.code, list): + for stmt in self.code: + visitor.visit(stmt) + else: + visitor.visit(self.code) return visitor.free_symbols return set() From 6dbbd3f4ed89a2dc78f4ad12678f47244835c434 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Mon, 23 Jan 2023 14:04:48 +0100 Subject: [PATCH 54/98] Support library nodes with no connectors. --- dace/frontend/python/replacements.py | 31 ++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index f8725c54d5..bb24a303d8 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -4668,18 +4668,37 @@ def _tasklet(pv: 'ProgramVisitor', @oprepo.replaces('dace.tree.library') -def _library(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, ltype: type, label: StringLiteral, inputs: Dict[StringLiteral, str], outputs: Dict[StringLiteral, str], **kwargs): +def _library(pv: 'ProgramVisitor', + sdfg: SDFG, + state: SDFGState, + ltype: type, + label: StringLiteral, + inputs: Union[Dict[StringLiteral, str], Set[str]], + outputs: Union[Dict[StringLiteral, str], Set[str]], + **kwargs): + + # Extract strings from StringLiterals label = label.value - inputs = {k.value: v for k, v in inputs.items()} - outputs = {k.value: v for k, v in outputs.items()} + if isinstance(inputs, dict): + inputs = {k.value: v for k, v in inputs.items()} + else: + inputs = {i: v for i, v in enumerate(inputs)} + if 
isinstance(outputs, dict): + outputs = {k.value: v for k, v in outputs.items()} + else: + outputs = {i: v for i, v in enumerate(outputs)} + + # Create LibraryNode tasklet = ltype(label, **kwargs) state.add_node(tasklet) - for conn, name in inputs.items(): + for k, name in inputs.items(): access = state.add_access(name) + conn = k if isinstance(k, str) else None state.add_edge(access, None, tasklet, conn, Memlet.from_array(name, sdfg.arrays[name])) - for conn, name in outputs.items(): + for k, name in outputs.items(): access = state.add_access(name) memlet = Memlet.from_array(name, sdfg.arrays[name]) + conn = k if isinstance(k, str) else None state.add_edge(tasklet, conn, access, None, memlet) # Handle scope output @@ -4693,7 +4712,7 @@ def _library(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, ltype: type, la @oprepo.replaces('dace.tree.copy') -def _tasklet(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, src: str, dst: str, wcr: str = None): +def _copy(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, src: str, dst: str, wcr: str = None): src_access = state.add_access(src) dst_access = state.add_access(dst) state.add_nedge(src_access, dst_access, Memlet.from_array(dst, sdfg.arrays[dst], wcr=None)) From cf68922428b9f431ec093577b7fac2781afd342d Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Mon, 23 Jan 2023 14:05:33 +0100 Subject: [PATCH 55/98] Reduce node must have `name` as a parameter (like other LibraryNodes). 
--- dace/libraries/standard/nodes/reduce.py | 3 ++- dace/sdfg/analysis/schedule_tree/replacements.py | 12 ++++++++++++ dace/sdfg/state.py | 2 +- 3 files changed, 15 insertions(+), 2 deletions(-) create mode 100644 dace/sdfg/analysis/schedule_tree/replacements.py diff --git a/dace/libraries/standard/nodes/reduce.py b/dace/libraries/standard/nodes/reduce.py index 8423022fae..5807d34ad6 100644 --- a/dace/libraries/standard/nodes/reduce.py +++ b/dace/libraries/standard/nodes/reduce.py @@ -1562,13 +1562,14 @@ class Reduce(dace.sdfg.nodes.LibraryNode): identity = Property(allow_none=True) def __init__(self, + name, wcr='lambda a, b: a', axes=None, identity=None, schedule=dtypes.ScheduleType.Default, debuginfo=None, **kwargs): - super().__init__(name='Reduce', **kwargs) + super().__init__(name=name, **kwargs) self.wcr = wcr self.axes = axes self.identity = identity diff --git a/dace/sdfg/analysis/schedule_tree/replacements.py b/dace/sdfg/analysis/schedule_tree/replacements.py new file mode 100644 index 0000000000..1df8fb4e6a --- /dev/null +++ b/dace/sdfg/analysis/schedule_tree/replacements.py @@ -0,0 +1,12 @@ +# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. 
+from dace import SDFG, SDFGState +from dace.frontend.common import op_repository as oprepo +from typing import Tuple + + +@oprepo.replaces('dace.tree.library') +def _library(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, ltype: str, label: str, inputs: Tuple[str], outputs: Tuple[str]): + print(ltype) + print(label) + print(inputs) + print(outputs) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 4ef09012fe..f6912d1e87 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -1390,7 +1390,7 @@ def add_reduce( """ import dace.libraries.standard as stdlib # Avoid import loop debuginfo = _getdebuginfo(debuginfo or self._default_lineinfo) - result = stdlib.Reduce(wcr, axes, identity, schedule=schedule, debuginfo=debuginfo) + result = stdlib.Reduce('Reduce', wcr, axes, identity, schedule=schedule, debuginfo=debuginfo) self.add_node(result) return result From 81196bdc7f775a5b9e0ff0b13be917de8478d7fd Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Mon, 23 Jan 2023 14:06:11 +0100 Subject: [PATCH 56/98] IfScope should check conditions for potentially used data.
--- dace/sdfg/analysis/schedule_tree/treenodes.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index daa2b4fd56..1a56334248 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -30,7 +30,7 @@ def as_string(self, indent: int = 0): def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: string, defined_arrays = self.define_arrays(indent, defined_arrays) - return string + indent * INDENTATION + 'UNSUPPORTED', defined_arrays + return string + indent * INDENTATION + 'pass', defined_arrays def define_arrays(self, indent: int, defined_arrays: Set[str]) -> Tuple[str, Set[str]]: return '', defined_arrays @@ -114,7 +114,10 @@ def {self.sdfg.label}({self.sdfg.python_signature()}): body = '' undefined_arrays = {name: desc for name, desc in self.containers.items() if name not in defined_arrays} for name, desc in undefined_arrays.items(): - definitions += cindent * INDENTATION + f"{name} = numpy.ndarray({desc.shape}, {TYPECLASS_TO_STRING[desc.dtype].replace('::', '.')})\n" + if isinstance(desc, data.Scalar): + definitions += cindent * INDENTATION + f"{name} = numpy.{desc.dtype.as_numpy_dtype()}(0)\n" + else: + definitions += cindent * INDENTATION + f"{name} = numpy.ndarray({desc.shape}, {TYPECLASS_TO_STRING[desc.dtype].replace('::', '.')})\n" defined_arrays |= undefined_arrays.keys() for child in self.children: substring, defined_arrays = child.as_python(indent + 1, defined_arrays) @@ -275,6 +278,11 @@ def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[s result = indent * INDENTATION + f'if {self.condition.as_string}:\n' string, defined_arrays = super().as_python(indent, defined_arrays) return result + string, defined_arrays + + def is_data_used(self, name: str) -> bool: + result = name in self.condition.get_free_symbols() + 
result |= super().is_data_used(name) + return result @dataclass From e8c7665290bccced8c97a8c444f059e7ec00e8af Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Mon, 23 Jan 2023 14:07:15 +0100 Subject: [PATCH 57/98] updated test --- tests/schedule_tree/conversion_test.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/schedule_tree/conversion_test.py b/tests/schedule_tree/conversion_test.py index 602644f172..f394a1fa61 100644 --- a/tests/schedule_tree/conversion_test.py +++ b/tests/schedule_tree/conversion_test.py @@ -57,13 +57,10 @@ def dace_azimint_naive(data: dace.float64[N], radius: dace.float64[N]): rng = np.random.default_rng(42) SN, Snpt = 1000, 10 data, radius = rng.random((SN, )), rng.random((SN, )) - # ref = dace_azimint_naive(data, radius, npt=Snpt) + ref = dace_azimint_naive(data, radius, npt=Snpt) sdfg0 = dace_azimint_naive.to_sdfg() tree = as_schedule_tree(sdfg0) - print(tree.as_string()) - pcode, _ = tree.as_python() - print(pcode) sdfg1 = as_sdfg(tree) val = sdfg1(data=data, radius=radius, N=SN, npt=Snpt) From 5fbeffdf628299669cccf07ecd66bdb4fee8a220 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Mon, 23 Jan 2023 15:24:15 +0100 Subject: [PATCH 58/98] Print just `bool` instead of `dace::bool` when converting data to boolean. 
--- dace/frontend/python/replacements.py | 3 ++- tests/schedule_tree/conversions/map_test.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index bb24a303d8..950227ff80 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -4354,7 +4354,8 @@ def _datatype_converter(sdfg: SDFG, state: SDFGState, arg: UfuncInput, dtype: dt 'name': "_convert_to_{}_".format(dtype.to_string()), 'inputs': ['__inp'], 'outputs': ['__out'], - 'code': "__out = dace.{}(__inp)".format(dtype.to_string()) + 'code': "__out = {}(__inp)".format(f"dace.{dtype.to_string()}" if dtype not in (dace.bool, dace.bool_) + else dtype.to_string()) } tasklet_params = _set_tasklet_params(impl, [arg]) diff --git a/tests/schedule_tree/conversions/map_test.py b/tests/schedule_tree/conversions/map_test.py index 8761d82c3c..7cb89b01aa 100644 --- a/tests/schedule_tree/conversions/map_test.py +++ b/tests/schedule_tree/conversions/map_test.py @@ -26,7 +26,6 @@ def simple_map(A: dace.float32[M, N]): sdfg_pre = simple_map.to_sdfg() tree = as_schedule_tree(sdfg_pre) - print(tree.as_string()) sdfg_post = as_sdfg(tree) rng = np.random.default_rng(42) From 09134bc1f2bd9364e129ce0eedbc68bac0b1210e Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Mon, 23 Jan 2023 15:24:42 +0100 Subject: [PATCH 59/98] Working on if canonicalization and fission tests. 
--- .../passes/canonicalize_if_test.py | 117 ++++++++++++++++++ .../transformations/if_fission_test.py | 74 +++++++++++ 2 files changed, 191 insertions(+) create mode 100644 tests/schedule_tree/passes/canonicalize_if_test.py create mode 100644 tests/schedule_tree/transformations/if_fission_test.py diff --git a/tests/schedule_tree/passes/canonicalize_if_test.py b/tests/schedule_tree/passes/canonicalize_if_test.py new file mode 100644 index 0000000000..af4b5c10ea --- /dev/null +++ b/tests/schedule_tree/passes/canonicalize_if_test.py @@ -0,0 +1,117 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Tests canonicalization of If/Elif/Else scopes. """ +import dace +import numpy as np +from dace.sdfg.analysis.schedule_tree import passes, treenodes as tnodes +from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree, as_sdfg + + +class IfCounter(tnodes.ScheduleNodeVisitor): + + if_count: int + elif_count: int + else_count: int + + def __init__(self): + self.if_count = 0 + self.elif_count = 0 + self.else_count = 0 + + def visit_IfScope(self, node: tnodes.IfScope): + self.if_count += 1 + self.generic_visit(node) + + def visit_ElifScope(self, node: tnodes.ElifScope): + self.elif_count += 1 + self.generic_visit(node) + + def visit_ElseScope(self, node: tnodes.ElseScope): + self.else_count += 1 + self.generic_visit(node) + + +def test_ifelifelse_canonicalization(): + + @dace.program + def ifelifelse(c: dace.int64): + out = 0 + if c < 0: + out = c - 1 + elif c == 0: + pass + else: + out = c % 2 + return out + + sdfg_pre = ifelifelse.to_sdfg() + tree = as_schedule_tree(sdfg_pre) + ifcounter_pre = IfCounter() + ifcounter_pre.visit(tree) + ifcount = ifcounter_pre.if_count + ifcounter_pre.elif_count + ifcounter_pre.else_count + + passes.canonicalize_if(tree) + ifcounter_post = IfCounter() + ifcounter_post.visit(tree) + assert ifcounter_post.if_count == ifcount + assert ifcounter_post.elif_count == 0 + assert 
ifcounter_post.else_count == 0 + + sdfg_post = as_sdfg(tree) + + for c in (-100, 0, 100): + ref = ifelifelse.f(c) + val0 = sdfg_pre(c=c) + val1 = sdfg_post(c=c) + assert val0[0] == ref + assert val1[0] == ref + + +def test_ifelifelse_canonicalization2(): + + @dace.program + def ifelifelse2(c: dace.int64): + out = 0 + if c < 0: + if c < -100: + out = c + 1 + elif c < -50: + out = c + 2 + else: + out = c + 3 + elif c == 0: + pass + else: + if c > 100: + out = c % 2 + elif c > 50: + out = c % 3 + else: + out = c % 4 + return out + + sdfg_pre = ifelifelse2.to_sdfg() + tree = as_schedule_tree(sdfg_pre) + ifcounter_pre = IfCounter() + ifcounter_pre.visit(tree) + ifcount = ifcounter_pre.if_count + ifcounter_pre.elif_count + ifcounter_pre.else_count + + passes.canonicalize_if(tree) + ifcounter_post = IfCounter() + ifcounter_post.visit(tree) + assert ifcounter_post.if_count == ifcount + assert ifcounter_post.elif_count == 0 + assert ifcounter_post.else_count == 0 + + sdfg_post = as_sdfg(tree) + + for c in (-200, -70, -20, 0, 15, 67, 122): + ref = ifelifelse2.f(c) + val0 = sdfg_pre(c=c) + val1 = sdfg_post(c=c) + assert val0[0] == ref + assert val1[0] == ref + + +if __name__ == "__main__": + test_ifelifelse_canonicalization() + test_ifelifelse_canonicalization2() diff --git a/tests/schedule_tree/transformations/if_fission_test.py b/tests/schedule_tree/transformations/if_fission_test.py new file mode 100644 index 0000000000..717d6a2469 --- /dev/null +++ b/tests/schedule_tree/transformations/if_fission_test.py @@ -0,0 +1,74 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Tests fission of If scopes. 
""" +import dace +import numpy as np +from dace.sdfg.analysis.schedule_tree import passes, transformations as ttrans, treenodes as tnodes +from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree, as_sdfg + + +class IfCounter(tnodes.ScheduleNodeVisitor): + + if_count: int + elif_count: int + else_count: int + + def __init__(self): + self.if_count = 0 + self.elif_count = 0 + self.else_count = 0 + + def visit_IfScope(self, node: tnodes.IfScope): + self.if_count += 1 + self.generic_visit(node) + + def visit_ElifScope(self, node: tnodes.ElifScope): + self.elif_count += 1 + self.generic_visit(node) + + def visit_ElseScope(self, node: tnodes.ElseScope): + self.else_count += 1 + self.generic_visit(node) + + +def test_if_fission(): + + @dace.program + def ifelifelse(c: dace.int64): + out0, out1 = 0, 0 + if c < 0: + out0 = -5 + out1 = -10 + if c == 0: + pass + if c > 0: + out0 = 5 + out1 = 10 + return out0, out1 + + sdfg_pre = ifelifelse.to_sdfg() + tree = as_schedule_tree(sdfg_pre) + ifcounter_pre = IfCounter() + ifcounter_pre.visit(tree) + if ifcounter_pre.elif_count > 0 or ifcounter_pre.else_count > 0: + passes.canonicalize_if(tree) + ifcounter_post = IfCounter() + ifcounter_post.visit(tree) + assert ifcounter_post.elif_count == 0 + assert ifcounter_post.else_count == 0 + + for child in list(tree.children): + if isinstance(child, tnodes.IfScope): + ttrans.if_fission(child) + + sdfg_post = as_sdfg(tree) + + for c in (-100, 0, 100): + ref = ifelifelse.f(c) + val0 = sdfg_pre(c=c) + val1 = sdfg_post(c=c) + assert val0 == ref + assert val1 == ref + + +if __name__ == "__main__": + test_if_fission() From b818516b21bfd7768ac23193ed08d138f58def85 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 30 Jun 2023 09:53:17 -0700 Subject: [PATCH 60/98] Add schedule-related tests --- tests/schedule_tree/schedule_test.py | 251 +++++++++++++++++++++++++++ 1 file changed, 251 insertions(+) create mode 100644 tests/schedule_tree/schedule_test.py diff --git 
a/tests/schedule_tree/schedule_test.py b/tests/schedule_tree/schedule_test.py new file mode 100644 index 0000000000..c72822bd1a --- /dev/null +++ b/tests/schedule_tree/schedule_test.py @@ -0,0 +1,251 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import dace +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree +import numpy as np + + +def test_for_in_map_in_for(): + @dace.program + def matmul(A: dace.float32[10, 10], B: dace.float32[10, 10], C: dace.float32[10, 10]): + for i in range(10): + for j in dace.map[0:10]: + atile = dace.define_local([10], dace.float32) + atile[:] = A[i] + for k in range(10): + with dace.tasklet: + a << atile[k] + b << B[k, j] + cin << C[i, j] + c >> C[i, j] + c = cin + a * b + + sdfg = matmul.to_sdfg() + stree = as_schedule_tree(sdfg) + + assert len(stree.children) == 1 # for + fornode = stree.children[0] + assert isinstance(fornode, tn.ForScope) + assert len(fornode.children) == 1 # map + mapnode = fornode.children[0] + assert isinstance(mapnode, tn.MapScope) + assert len(mapnode.children) == 2 # copy, for + copynode, fornode = mapnode.children + assert isinstance(copynode, tn.CopyNode) + assert isinstance(fornode, tn.ForScope) + assert len(fornode.children) == 1 # tasklet + tasklet = fornode.children[0] + assert isinstance(tasklet, tn.TaskletNode) + + +def test_libnode(): + M, N, K = (dace.symbol(s) for s in 'MNK') + + @dace.program + def matmul_lib(a: dace.float64[M, K], b: dace.float64[K, N]): + return a @ b + + sdfg = matmul_lib.to_sdfg() + stree = as_schedule_tree(sdfg) + assert len(stree.children) == 1 + assert isinstance(stree.children[0], tn.LibraryCall) + assert (stree.children[0].as_string() == + '__return[0:M, 0:N] = library MatMul[alpha=1, beta=0](a[0:M, 0:K], b[0:K, 0:N])') + + +def test_nesting(): + @dace.program + def nest2(a: dace.float64[10]): + a += 1 + + @dace.program + def nest1(a: dace.float64[5, 
10]): + for i in range(5): + nest2(a[:, i]) + + @dace.program + def main(a: dace.float64[20, 10]): + nest1(a[:5]) + nest1(a[5:10]) + nest1(a[10:15]) + nest1(a[15:]) + + sdfg = main.to_sdfg() + stree = as_schedule_tree(sdfg) + + # Despite two levels of nesting, immediate children are the 4 for loops + assert len(stree.children) == 4 + offsets = ['', '5', '10', '15'] + for fornode, offset in zip(stree.children, offsets): + assert isinstance(fornode, tn.ForScope) + assert len(fornode.children) == 1 # map + mapnode = fornode.children[0] + assert isinstance(mapnode, tn.MapScope) + assert len(mapnode.children) == 1 # tasklet + tasklet = mapnode.children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert offset in str(next(iter(tasklet.in_memlets.values()))) + + +def test_nesting_view(): + @dace.program + def nest2(a: dace.float64[40]): + a += 1 + + @dace.program + def nest1(a): + for i in range(5): + subset = a[:, i, :] + nest2(subset.reshape((40, ))) + + @dace.program + def main(a: dace.float64[20, 10]): + nest1(a.reshape((4, 5, 10))) + + sdfg = main.to_sdfg() + stree = as_schedule_tree(sdfg) + assert any(isinstance(node, tn.ViewNode) for node in stree.children) + + +def test_nesting_nview(): + @dace.program + def nest2(a: dace.float64[40]): + a += 1 + + @dace.program + def nest1(a: dace.float64[4, 5, 10]): + for i in range(5): + nest2(a[:, i, :]) + + @dace.program + def main(a: dace.float64[20, 10]): + nest1(a) + + sdfg = main.to_sdfg() + stree = as_schedule_tree(sdfg) + assert any(isinstance(node, tn.NView) for node in stree.children) + + +def test_irreducible_sub_sdfg(): + sdfg = dace.SDFG('irreducible') + # Add a simple chain + s = sdfg.add_state_after(sdfg.add_state_after(sdfg.add_state())) + # Add an irreducible CFG + s1 = sdfg.add_state() + s2 = sdfg.add_state() + + sdfg.add_edge(s, s1, dace.InterstateEdge('a < b')) + # sdfg.add_edge(s, s2, dace.InterstateEdge('a >= b')) + sdfg.add_edge(s1, s2, dace.InterstateEdge('b > 9')) + sdfg.add_edge(s2, s1, 
dace.InterstateEdge('b < 19')) + e = sdfg.add_state() + sdfg.add_edge(s1, e, dace.InterstateEdge('a < 0')) + sdfg.add_edge(s2, e, dace.InterstateEdge('b < 0')) + + # Add a loop following general block + sdfg.add_loop(e, sdfg.add_state(), None, 'i', '0', 'i < 10', 'i + 1') + + # TODO: Missing exit in stateif s2->e + # print(as_schedule_tree(sdfg).as_string()) + + +def test_irreducible_in_loops(): + sdfg = dace.SDFG('irreducible') + # Add a simple chain of two for loops with goto from second to first's body + s1 = sdfg.add_state_after(sdfg.add_state_after(sdfg.add_state())) + s2 = sdfg.add_state() + e = sdfg.add_state() + + # Add a loop + l1 = sdfg.add_state() + l2 = sdfg.add_state_after(l1) + sdfg.add_loop(s1, l1, s2, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l2) + + l3 = sdfg.add_state() + l4 = sdfg.add_state_after(l3) + sdfg.add_loop(s2, l3, e, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l4) + + # Irreducible part + sdfg.add_edge(l3, l1, dace.InterstateEdge('i < 5')) + + # TODO: gblock must cover the greatest common scope its labels are in. 
+ # print(as_schedule_tree(sdfg).as_string()) + + +def test_reference(): + sdfg = dace.SDFG('tester') + sdfg.add_symbol('n', dace.int32) + sdfg.add_array('A', [20], dace.float64) + sdfg.add_array('B', [20], dace.float64) + sdfg.add_array('C', [20], dace.float64) + sdfg.add_reference('ref', [20], dace.float64) + + init = sdfg.add_state() + s1 = sdfg.add_state() + s2 = sdfg.add_state() + end = sdfg.add_state() + sdfg.add_edge(init, s1, dace.InterstateEdge('n > 0')) + sdfg.add_edge(init, s2, dace.InterstateEdge('n <= 0')) + sdfg.add_edge(s1, end, dace.InterstateEdge()) + sdfg.add_edge(s2, end, dace.InterstateEdge()) + + s1.add_edge(s1.add_access('A'), None, s1.add_access('ref'), 'set', dace.Memlet('A[0:20]')) + s2.add_edge(s2.add_access('B'), None, s2.add_access('ref'), 'set', dace.Memlet('B[0:20]')) + end.add_nedge(end.add_access('ref'), end.add_access('C'), dace.Memlet('ref[0:20]')) + + # TODO: Align reference memlet + # print(as_schedule_tree(sdfg).as_string()) + + +def test_code_to_code(): + sdfg = dace.SDFG('tester') + sdfg.add_scalar('scal', dace.int32, transient=True) + state = sdfg.add_state() + t1 = state.add_tasklet('a', {}, {'out'}, 'out = 5') + t2 = state.add_tasklet('b', {'inp'}, {}, 'print(inp)', side_effects=True) + state.add_edge(t1, 'out', t2, 'inp', dace.Memlet('scal')) + + # TODO: Nicely print tasklets without outputs + stree = as_schedule_tree(sdfg) + assert len(stree.children) == 2 + assert all(isinstance(c, tn.TaskletNode) for c in stree.children) + + +def test_dyn_map_range(): + H = dace.symbol() + nnz = dace.symbol('nnz') + W = dace.symbol() + + @dace.program + def spmv(A_row: dace.uint32[H + 1], A_col: dace.uint32[nnz], A_val: dace.float32[nnz], x: dace.float32[W]): + b = np.zeros([H], dtype=np.float32) + + for i in dace.map[0:H]: + for j in dace.map[A_row[i]:A_row[i + 1]]: + b[i] += A_val[j] * x[A_col[j]] + + return b + + sdfg = spmv.to_sdfg() + stree = as_schedule_tree(sdfg) + assert len(stree.children) == 2 + assert all(isinstance(c, 
tn.MapScope) for c in stree.children) + mapscope = stree.children[1] + start, end, dynrangemap = mapscope.children + assert isinstance(start, tn.DynScopeCopyNode) + assert isinstance(end, tn.DynScopeCopyNode) + assert isinstance(dynrangemap, tn.MapScope) + + +if __name__ == '__main__': + test_for_in_map_in_for() + test_libnode() + test_nesting() + test_nesting_view() + test_nesting_nview() + test_irreducible_sub_sdfg() + test_irreducible_in_loops() + test_reference() + test_code_to_code() + test_dyn_map_range() From 79499562d8b07f11343b6e32ed8de6f544199672 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 30 Jun 2023 09:53:52 -0700 Subject: [PATCH 61/98] Fix gblock creation arguments --- dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 8b4020c933..651bdce626 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -482,7 +482,7 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche subnodes.extend(totree(n, node)) if not node.sequential: # Nest in general block - result = [tn.GBlock(children=subnodes)] + result = [tn.GBlock(sdfg, top_level=False, children=subnodes)] else: # Use the sub-nodes directly result = subnodes From b15fa8d40e9b91b5cb920354c359a2a4c4925cf9 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 30 Jun 2023 09:56:27 -0700 Subject: [PATCH 62/98] Fix test names --- .../schedule_tree/conversions/{map_test.py => map_stree_test.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/schedule_tree/conversions/{map_test.py => map_stree_test.py} (100%) diff --git a/tests/schedule_tree/conversions/map_test.py b/tests/schedule_tree/conversions/map_stree_test.py similarity index 100% rename from tests/schedule_tree/conversions/map_test.py rename to 
tests/schedule_tree/conversions/map_stree_test.py From cef77e9d9a0b7fde7f0853d1a060cfa14a2f973a Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 18 Jul 2023 10:02:31 -0700 Subject: [PATCH 63/98] Fix simple_call for symbolic inputs --- dace/frontend/python/replacements.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index c031a6c5b5..78e1357223 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -606,9 +606,10 @@ def _elementwise(pv: 'ProgramVisitor', def _simple_call(sdfg: SDFG, state: SDFGState, inpname: str, func: str, restype: dace.typeclass = None): """ Implements a simple call of the form `out = func(inp)`. """ + create_input = True if isinstance(inpname, (list, tuple)): # TODO investigate this inpname = inpname[0] - if not isinstance(inpname, str): + if not isinstance(inpname, str) and not symbolic.issymbolic(inpname): # Constant parameter cst = inpname inparr = data.create_datadescriptor(cst) @@ -616,6 +617,10 @@ def _simple_call(sdfg: SDFG, state: SDFGState, inpname: str, func: str, restype: inparr.transient = True sdfg.add_constant(inpname, cst, inparr) sdfg.add_datadesc(inpname, inparr) + elif symbolic.issymbolic(inpname): + dtype = symbolic.symtype(inpname) + inparr = data.Scalar(dtype) + create_input = False else: inparr = sdfg.arrays[inpname] @@ -625,10 +630,16 @@ def _simple_call(sdfg: SDFG, state: SDFGState, inpname: str, func: str, restype: outarr.dtype = restype num_elements = data._prod(inparr.shape) if num_elements == 1: - inp = state.add_read(inpname) + if create_input: + inp = state.add_read(inpname) + inconn_name = '__inp' + else: + inconn_name = symbolic.symstr(inpname) + out = state.add_write(outname) - tasklet = state.add_tasklet(func, {'__inp'}, {'__out'}, '__out = {f}(__inp)'.format(f=func)) - state.add_edge(inp, None, tasklet, '__inp', Memlet.from_array(inpname, inparr)) + 
tasklet = state.add_tasklet(func, {'__inp'} if create_input else {}, {'__out'}, f'__out = {func}({inconn_name})') + if create_input: + state.add_edge(inp, None, tasklet, '__inp', Memlet.from_array(inpname, inparr)) state.add_edge(tasklet, '__out', out, None, Memlet.from_array(outname, outarr)) else: state.add_mapped_tasklet( From c7ad062e64e54dc0f846d17eb45c15e7a438a8ce Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Wed, 19 Jul 2023 09:15:08 -0700 Subject: [PATCH 64/98] More optional set property support --- dace/properties.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dace/properties.py b/dace/properties.py index 7ca97db204..bf60697b0a 100644 --- a/dace/properties.py +++ b/dace/properties.py @@ -888,6 +888,8 @@ def from_string(s): return [eval(i) for i in re.sub(r"[\{\}\(\)\[\]]", "", s).split(",")] def to_json(self, l): + if l is None: + return None return list(sorted(l)) def from_json(self, l, sdfg=None): From 33f13cf14d2c4dd033db6d5c62f0516bb7a91e83 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Wed, 19 Jul 2023 11:32:44 -0700 Subject: [PATCH 65/98] Traversal methods for strees --- dace/sdfg/analysis/schedule_tree/treenodes.py | 59 +++++++++++++------ 1 file changed, 41 insertions(+), 18 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 1a56334248..877093317d 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -10,7 +10,7 @@ from dace.symbolic import symbol from dace.memlet import Memlet from functools import reduce -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Dict, Iterator, List, Optional, Set, Tuple, Union INDENTATION = ' ' @@ -27,11 +27,11 @@ class ScheduleTreeNode: def as_string(self, indent: int = 0): return indent * INDENTATION + 'UNSUPPORTED' - + def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: string, defined_arrays = 
self.define_arrays(indent, defined_arrays) return string + indent * INDENTATION + 'pass', defined_arrays - + def define_arrays(self, indent: int, defined_arrays: Set[str]) -> Tuple[str, Set[str]]: return '', defined_arrays # defined_arrays = defined_arrays or set() @@ -48,7 +48,7 @@ def define_arrays(self, indent: int, defined_arrays: Set[str]) -> Tuple[str, Set # string += indent * INDENTATION + f"{name} = numpy.ndarray({desc.shape}, {TYPECLASS_TO_STRING[desc.dtype].replace('::', '.')})\n" # defined_arrays |= undefined_arrays.keys() # return string, defined_arrays - + def is_data_used(self, name: str, include_symbols: bool = False) -> bool: pass # for child in self.children: @@ -56,6 +56,12 @@ def is_data_used(self, name: str, include_symbols: bool = False) -> bool: # return True # return False + def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: + """ + Traverse tree nodes in a pre-order manner. + """ + yield self + @dataclass class ScheduleTreeScope(ScheduleTreeNode): @@ -66,7 +72,10 @@ class ScheduleTreeScope(ScheduleTreeNode): symbols: Optional[Dict[str, symbol]] = field(default_factory=dict, init=False) # def __init__(self, sdfg: Optional[SDFG] = None, top_level: Optional[bool] = False, children: Optional[List['ScheduleTreeNode']] = None): - def __init__(self, sdfg: Optional[SDFG] = None, top_level: Optional[bool] = False, children: Optional[List['ScheduleTreeNode']] = None): + def __init__(self, + sdfg: Optional[SDFG] = None, + top_level: Optional[bool] = False, + children: Optional[List['ScheduleTreeNode']] = None): self.sdfg = sdfg self.top_level = top_level self.children = children or [] @@ -81,7 +90,7 @@ def __init__(self, sdfg: Optional[SDFG] = None, top_level: Optional[bool] = Fals # if top_level: # self.containers.update({name: copy.deepcopy(desc) for name, desc in sdfg.arrays.items() if not desc.transient}) # # self.containers = {name: copy.deepcopy(container) for name, container in sdfg.arrays.items()} - + # def __post_init__(self): # for 
child in self.children: # child.parent = self @@ -93,8 +102,20 @@ def __init__(self, sdfg: Optional[SDFG] = None, top_level: Optional[bool] = Fals def as_string(self, indent: int = 0): return '\n'.join([child.as_string(indent + 1) for child in self.children]) - - def as_python(self, indent: int = 0, defined_arrays: Set[str] = None, def_offset: int = 1, sep_defs: bool = False) -> Tuple[str, Set[str]]: + + def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: + """ + Traverse tree nodes in a pre-order manner. + """ + yield from super().preorder_traversal() + for child in self.children: + yield from child.preorder_traversal() + + def as_python(self, + indent: int = 0, + defined_arrays: Set[str] = None, + def_offset: int = 1, + sep_defs: bool = False) -> Tuple[str, Set[str]]: if self.top_level: header = '' for s in self.sdfg.free_symbols: @@ -104,7 +125,7 @@ def as_python(self, indent: int = 0, defined_arrays: Set[str] = None, def_offset def {self.sdfg.label}({self.sdfg.python_signature()}): """ # defined_arrays = set([name for name, desc in self.sdfg.arrays.items() if not desc.transient]) - defined_arrays = set([name for name, desc in self.containers.items() if not desc.transient]) + defined_arrays = set([name for name, desc in self.containers.items() if not desc.transient]) else: header = '' defined_arrays = defined_arrays or set() @@ -128,13 +149,15 @@ def {self.sdfg.label}({self.sdfg.python_signature()}): return definitions, body, defined_arrays else: return header + definitions + body, defined_arrays - + def define_arrays(self, indent: int, defined_arrays: Set[str]) -> Tuple[str, Set[str]]: defined_arrays = defined_arrays or set() string = '' undefined_arrays = {} for sdfg in self.sdfg.all_sdfgs_recursive(): - undefined_arrays.update({name: desc for name, desc in sdfg.arrays.items() if not name in defined_arrays and desc.transient}) + undefined_arrays.update( + {name: desc + for name, desc in sdfg.arrays.items() if not name in defined_arrays and 
desc.transient}) # undefined_arrays = {name: desc for name, desc in self.sdfg.arrays.items() if not name in defined_arrays and desc.transient} times_used = {name: 0 for name in undefined_arrays} for child in self.children: @@ -149,7 +172,7 @@ def define_arrays(self, indent: int, defined_arrays: Set[str]) -> Tuple[str, Set self.containers[name] = copy.deepcopy(desc) defined_arrays |= undefined_arrays.keys() return string, defined_arrays - + def is_data_used(self, name: str) -> bool: for child in self.children: if child.is_data_used(name): @@ -227,7 +250,7 @@ def as_string(self, indent: int = 0): result = (indent * INDENTATION + f'for {node.itervar} = {node.init}; {node.condition.as_string}; ' f'{node.itervar} = {node.update}:\n') return result + super().as_string(indent) - + def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: node = self.header result = indent * INDENTATION + f'{node.itervar} = {node.init}\n' @@ -273,12 +296,12 @@ class IfScope(ControlFlowScope): def as_string(self, indent: int = 0): result = indent * INDENTATION + f'if {self.condition.as_string}:\n' return result + super().as_string(indent) - + def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: result = indent * INDENTATION + f'if {self.condition.as_string}:\n' string, defined_arrays = super().as_python(indent, defined_arrays) return result + string, defined_arrays - + def is_data_used(self, name: str) -> bool: result = name in self.condition.get_free_symbols() result |= super().is_data_used(name) @@ -490,7 +513,7 @@ def as_string(self, indent: int = 0): wcr = '' return indent * INDENTATION + f'{self.target}{offset} = copy {self.memlet.data}[{self.memlet.subset}]{wcr}' - + def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: if self.memlet.other_subset is not None and any(s != 0 for s in self.memlet.other_subset.min_element()): offset = f'[{self.memlet.other_subset}]' @@ 
-531,12 +554,12 @@ class ViewNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = view {self.memlet} as {self.view_desc.shape}' - + def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: defined_arrays = defined_arrays or set() string, defined_arrays = self.define_arrays(indent, defined_arrays) return string + indent * INDENTATION + f"{self.target} = {self.memlet}", defined_arrays - + def is_data_used(self, name: str) -> bool: # NOTE: View data must not be considered used return name is self.memlet.data From 94781f2cc40c6bb1cda4856454e612d0ecb25531 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Wed, 19 Jul 2023 11:33:24 -0700 Subject: [PATCH 66/98] Fix bug in ConstantProp where a symbol_mapping symbol is used in first state then eliminated completely --- dace/sdfg/sdfg.py | 5 +++-- .../passes/constant_propagation.py | 16 ++++++++++------ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index ba5859388e..b42df62ea4 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -1307,7 +1307,8 @@ def free_symbols(self) -> Set[str]: ordered_states = self.nodes() for state in ordered_states: - free_syms |= state.free_symbols + state_fsyms = state.free_symbols + free_syms |= state_fsyms # Add free inter-state symbols for e in self.out_edges(state): @@ -1315,7 +1316,7 @@ def free_symbols(self) -> Set[str]: # subracting the (true) free symbols from the edge's assignment keys. This way we can correctly # compute the symbols that are used before being assigned. 
efsyms = e.data.free_symbols - defined_syms |= set(e.data.assignments.keys()) - efsyms + defined_syms |= set(e.data.assignments.keys()) - (efsyms | state_fsyms) used_before_assignment.update(efsyms - defined_syms) free_syms |= efsyms diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index c197adf827..dd2523c005 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -102,12 +102,8 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = for e in sdfg.out_edges(state): e.data.replace_dict(mapping, replace_keys=False) - # If symbols are never unknown any longer, remove from SDFG + # Gather initial propagated symbols result = {k: v for k, v in symbols_replaced.items() if k not in remaining_unknowns} - # Remove from symbol repository - for sym in result: - if sym in sdfg.symbols: - sdfg.remove_symbol(sym) # Remove single-valued symbols from data descriptors (e.g., symbolic array size) sdfg.replace_dict({k: v @@ -121,6 +117,14 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = for sym in intersection: del edge.data.assignments[sym] + # If symbols are never unknown any longer, remove from SDFG + fsyms = sdfg.free_symbols + result = {k: v for k, v in result.items() if k not in fsyms} + for sym in result: + if sym in sdfg.symbols: + # Remove from symbol repository and nested SDFG symbol mapipng + sdfg.remove_symbol(sym) + result = set(result.keys()) if self.recursive: @@ -188,7 +192,7 @@ def collect_constants(self, if len(in_edges) == 1: # Special case, propagate as-is if state not in result: # Condition evaluates to False when state is the start-state result[state] = {} - + # First the prior state if in_edges[0].src in result: # Condition evaluates to False when state is the start-state self._propagate(result[state], result[in_edges[0].src]) From 
f34b5819616eca9220f2205315396622d1707ffd Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Wed, 19 Jul 2023 11:33:47 -0700 Subject: [PATCH 67/98] No more UB in test --- tests/schedule_tree/schedule_test.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/schedule_tree/schedule_test.py b/tests/schedule_tree/schedule_test.py index c72822bd1a..9827c3cd50 100644 --- a/tests/schedule_tree/schedule_test.py +++ b/tests/schedule_tree/schedule_test.py @@ -7,6 +7,7 @@ def test_for_in_map_in_for(): + @dace.program def matmul(A: dace.float32[10, 10], B: dace.float32[10, 10], C: dace.float32[10, 10]): for i in range(10): @@ -55,6 +56,7 @@ def matmul_lib(a: dace.float64[M, K], b: dace.float64[K, N]): def test_nesting(): + @dace.program def nest2(a: dace.float64[10]): a += 1 @@ -89,6 +91,7 @@ def main(a: dace.float64[20, 10]): def test_nesting_view(): + @dace.program def nest2(a: dace.float64[40]): a += 1 @@ -109,6 +112,7 @@ def main(a: dace.float64[20, 10]): def test_nesting_nview(): + @dace.program def nest2(a: dace.float64[40]): a += 1 @@ -169,6 +173,9 @@ def test_irreducible_in_loops(): # Irreducible part sdfg.add_edge(l3, l1, dace.InterstateEdge('i < 5')) + # Avoiding undefined behavior + sdfg.edges_between(l3, l4)[0].data.condition.as_string = 'i >= 5' + # TODO: gblock must cover the greatest common scope its labels are in. # print(as_schedule_tree(sdfg).as_string()) From 250f62b6de7a7d4ddaaca32b8411340c94960c73 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Wed, 19 Jul 2023 11:51:34 -0700 Subject: [PATCH 68/98] Naming tests --- tests/schedule_tree/naming_test.py | 185 +++++++++++++++++++++++++++++ 1 file changed, 185 insertions(+) create mode 100644 tests/schedule_tree/naming_test.py diff --git a/tests/schedule_tree/naming_test.py b/tests/schedule_tree/naming_test.py new file mode 100644 index 0000000000..d2eddc4bc4 --- /dev/null +++ b/tests/schedule_tree/naming_test.py @@ -0,0 +1,185 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. 
All rights reserved. +import dace +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree +from dace.transformation.passes.constant_propagation import ConstantPropagation + +import pytest +from typing import List + + +def _irreducible_loop_to_loop(): + sdfg = dace.SDFG('irreducible') + # Add a simple chain of two for loops with goto from second to first's body + s1 = sdfg.add_state_after(sdfg.add_state_after(sdfg.add_state())) + s2 = sdfg.add_state() + e = sdfg.add_state() + + # Add a loop + l1 = sdfg.add_state() + l2 = sdfg.add_state_after(l1) + sdfg.add_loop(s1, l1, s2, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l2) + + l3 = sdfg.add_state() + l4 = sdfg.add_state_after(l3) + sdfg.add_loop(s2, l3, e, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l4) + + # Irreducible part + sdfg.add_edge(l3, l1, dace.InterstateEdge('i < 5')) + + # Avoiding undefined behavior + sdfg.edges_between(l3, l4)[0].data.condition.as_string = 'i >= 5' + + return sdfg + + +def _nested_irreducible_loops(): + sdfg = _irreducible_loop_to_loop() + nsdfg = _irreducible_loop_to_loop() + + l1 = sdfg.node(5) + l1.add_nested_sdfg(nsdfg, None, {}, {}) + return sdfg + + +def test_clash_states(): + """ + Same test as test_irreducible_in_loops, but all states in the nested SDFG share names with the top SDFG + """ + sdfg = _nested_irreducible_loops() + + stree = as_schedule_tree(sdfg) + unique_names = set() + for node in stree.preorder_traversal(): + if isinstance(node, tn.StateLabel): + if node.state.name in unique_names: + raise NameError('Name clash') + unique_names.add(node.state.name) + + +@pytest.mark.parametrize('constprop', (False, True)) +def test_clash_symbol_mapping(constprop): + sdfg = dace.SDFG('tester') + sdfg.add_array('A', [200], dace.float64) + sdfg.add_symbol('M', dace.int64) + sdfg.add_symbol('N', dace.int64) + sdfg.add_symbol('k', dace.int64) + + state = sdfg.add_state() + state2 = sdfg.add_state() + 
sdfg.add_edge(state, state2, dace.InterstateEdge(assignments={'k': 'M + 1'})) + + nsdfg = dace.SDFG('nester') + nsdfg.add_symbol('M', dace.int64) + nsdfg.add_symbol('N', dace.int64) + nsdfg.add_symbol('k', dace.int64) + nsdfg.add_array('out', [100], dace.float64) + nsdfg.add_transient('tmp', [100], dace.float64) + nstate = nsdfg.add_state() + nstate2 = nsdfg.add_state() + nsdfg.add_edge(nstate, nstate2, dace.InterstateEdge(assignments={'k': 'M + 1'})) + + # Copy + # The code should end up as `tmp[N:N+2] <- out[M+1:M+3]` + # In the outer SDFG: `tmp[N:N+2] <- A[M+101:M+103]` + r = nstate.add_access('out') + w = nstate.add_access('tmp') + nstate.add_edge(r, None, w, None, dace.Memlet(data='out', subset='k:k+2', other_subset='M:M+2')) + + # Tasklet + # The code should end up as `tmp[M] -> Tasklet -> out[N + 1]` + # In the outer SDFG: `tmp[M] -> Tasklet -> A[N + 101]` + r = nstate2.add_access('tmp') + w = nstate2.add_access('out') + t = nstate2.add_tasklet('dosomething', {'a'}, {'b'}, 'b = a + 1') + nstate2.add_edge(r, None, t, 'a', dace.Memlet('tmp[N]')) + nstate2.add_edge(t, 'b', w, None, dace.Memlet('out[k]')) + + # Connect nested SDFG to parent SDFG with an offset memlet + nsdfg_node = state2.add_nested_sdfg(nsdfg, None, {}, {'out'}, {'N': 'M', 'M': 'N', 'k': 'k'}) + w = state2.add_write('A') + state2.add_edge(nsdfg_node, 'out', w, None, dace.Memlet('A[100:200]')) + + # Get rid of k + if constprop: + ConstantPropagation().apply_pass(sdfg, {}) + + stree = as_schedule_tree(sdfg) + assert len(stree.children) in (2, 4) # Either with assignments or without + + # With assignments + if len(stree.children) == 4: + assert constprop is False + assert isinstance(stree.children[0], tn.AssignNode) + assert isinstance(stree.children[1], tn.CopyNode) + assert isinstance(stree.children[2], tn.AssignNode) + assert isinstance(stree.children[3], tn.TaskletNode) + assert stree.children[1].memlet.data == 'A' + assert str(stree.children[1].memlet.src_subset) == 'k + 100:k + 102' + assert 
str(stree.children[1].memlet.dst_subset) == 'N:N + 2' + assert stree.children[3].in_memlets['a'].data == 'tmp' + assert str(stree.children[3].in_memlets['a'].src_subset) == 'M' + assert stree.children[3].out_memlets['b'].data == 'A' + assert str(stree.children[3].out_memlets['b'].dst_subset) == 'k + 100' + else: + assert constprop is True + assert isinstance(stree.children[0], tn.CopyNode) + assert isinstance(stree.children[1], tn.TaskletNode) + assert stree.children[0].memlet.data == 'A' + assert str(stree.children[0].memlet.src_subset) == 'M + 101:M + 103' + assert str(stree.children[0].memlet.dst_subset) == 'N:N + 2' + assert stree.children[1].in_memlets['a'].data == 'tmp' + assert str(stree.children[1].in_memlets['a'].src_subset) == 'M' + assert stree.children[1].out_memlets['b'].data == 'A' + assert str(stree.children[1].out_memlets['b'].dst_subset) == 'N + 101' + + +def test_edgecase_symbol_mapping(): + sdfg = dace.SDFG('tester') + sdfg.add_symbol('M', dace.int64) + sdfg.add_symbol('N', dace.int64) + + state = sdfg.add_state() + state2 = sdfg.add_state_after(state) + + nsdfg = dace.SDFG('nester') + nsdfg.add_symbol('M', dace.int64) + nsdfg.add_symbol('N', dace.int64) + nsdfg.add_symbol('k', dace.int64) + nstate = nsdfg.add_state() + nstate2 = nsdfg.add_state() + nsdfg.add_edge(nstate, nstate2, dace.InterstateEdge(assignments={'k': 'M + 1'})) + + state2.add_nested_sdfg(nsdfg, None, {}, {}, {'N': 'M', 'M': 'N', 'k': 'M + 1'}) + + stree = as_schedule_tree(sdfg) + assert len(stree.children) == 1 + assert isinstance(stree.children[0], tn.AssignNode) + # TODO: "assign M + 1 = (N + 1)", target should stay "k" + assert str(stree.children[0].name) == 'k' + + +def test_clash_iteration_symbols(): + sdfg = _nested_irreducible_loops() + + stree = as_schedule_tree(sdfg) + + def _traverse(node: tn.ScheduleTreeScope, scopes: List[str]): + for child in node.children: + if isinstance(child, tn.ForScope): + itervar = child.header.itervar + if itervar in scopes: + raise 
NameError('Nested scope redefines iteration variable') + _traverse(child, scopes + [itervar]) + elif isinstance(child, tn.ScheduleTreeScope): + _traverse(child, scopes) + + _traverse(stree, []) + + +if __name__ == '__main__': + test_clash_states() + test_clash_symbol_mapping(False) + test_clash_symbol_mapping(True) + test_edgecase_symbol_mapping() + test_clash_iteration_symbols() From 619c21e635057c039cbca07a32a6ebb8b7f0a6b6 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 20 Jul 2023 15:58:14 -0700 Subject: [PATCH 69/98] Move out elements from this PR to others --- dace/frontend/python/replacements.py | 135 +------ dace/sdfg/analysis/schedule_tree/passes.py | 141 ------- .../analysis/schedule_tree/replacements.py | 12 - .../analysis/schedule_tree/sdfg_to_tree.py | 88 ++--- .../analysis/schedule_tree/transformations.py | 357 ------------------ dace/sdfg/analysis/schedule_tree/treenodes.py | 225 ----------- dace/sdfg/analysis/schedule_tree/utils.py | 35 -- tests/schedule_tree/conversion_test.py | 72 ---- .../conversions/map_stree_test.py | 243 ------------ .../schedule_tree/conversions/tasklet_test.py | 141 ------- .../passes/canonicalize_if_test.py | 117 ------ .../transformations/if_fission_test.py | 74 ---- 12 files changed, 47 insertions(+), 1593 deletions(-) delete mode 100644 dace/sdfg/analysis/schedule_tree/replacements.py delete mode 100644 dace/sdfg/analysis/schedule_tree/transformations.py delete mode 100644 dace/sdfg/analysis/schedule_tree/utils.py delete mode 100644 tests/schedule_tree/conversion_test.py delete mode 100644 tests/schedule_tree/conversions/map_stree_test.py delete mode 100644 tests/schedule_tree/conversions/tasklet_test.py delete mode 100644 tests/schedule_tree/passes/canonicalize_if_test.py delete mode 100644 tests/schedule_tree/transformations/if_fission_test.py diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index 78e1357223..d2079b2a35 100644 --- a/dace/frontend/python/replacements.py +++ 
b/dace/frontend/python/replacements.py @@ -637,7 +637,8 @@ def _simple_call(sdfg: SDFG, state: SDFGState, inpname: str, func: str, restype: inconn_name = symbolic.symstr(inpname) out = state.add_write(outname) - tasklet = state.add_tasklet(func, {'__inp'} if create_input else {}, {'__out'}, f'__out = {func}({inconn_name})') + tasklet = state.add_tasklet(func, {'__inp'} if create_input else {}, {'__out'}, + f'__out = {func}({inconn_name})') if create_input: state.add_edge(inp, None, tasklet, '__inp', Memlet.from_array(inpname, inparr)) state.add_edge(tasklet, '__out', out, None, Memlet.from_array(outname, outarr)) @@ -2158,8 +2159,9 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op res = symbolic.equal(arr1.shape[-1], arr2.shape[-2]) if res is None: - warnings.warn(f'Last mode of first tesnsor/matrix {arr1.shape[-1]} and second-last mode of ' - f'second tensor/matrix {arr2.shape[-2]} may not match', UserWarning) + warnings.warn( + f'Last mode of first tesnsor/matrix {arr1.shape[-1]} and second-last mode of ' + f'second tensor/matrix {arr2.shape[-2]} may not match', UserWarning) elif not res: raise SyntaxError('Matrix dimension mismatch %s != %s' % (arr1.shape[-1], arr2.shape[-2])) @@ -2176,8 +2178,9 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op res = symbolic.equal(arr1.shape[-1], arr2.shape[0]) if res is None: - warnings.warn(f'Number of matrix columns {arr1.shape[-1]} and length of vector {arr2.shape[0]} ' - f'may not match', UserWarning) + warnings.warn( + f'Number of matrix columns {arr1.shape[-1]} and length of vector {arr2.shape[0]} ' + f'may not match', UserWarning) elif not res: raise SyntaxError("Number of matrix columns {} must match" "size of vector {}.".format(arr1.shape[1], arr2.shape[0])) @@ -2188,8 +2191,9 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op res = symbolic.equal(arr1.shape[0], arr2.shape[0]) if res is None: - warnings.warn(f'Length 
of vector {arr1.shape[0]} and number of matrix rows {arr2.shape[0]} ' - f'may not match', UserWarning) + warnings.warn( + f'Length of vector {arr1.shape[0]} and number of matrix rows {arr2.shape[0]} ' + f'may not match', UserWarning) elif not res: raise SyntaxError("Size of vector {} must match number of matrix " "rows {} must match".format(arr1.shape[0], arr2.shape[0])) @@ -2200,8 +2204,9 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op res = symbolic.equal(arr1.shape[0], arr2.shape[0]) if res is None: - warnings.warn(f'Length of first vector {arr1.shape[0]} and length of second vector {arr2.shape[0]} ' - f'may not match', UserWarning) + warnings.warn( + f'Length of first vector {arr1.shape[0]} and length of second vector {arr2.shape[0]} ' + f'may not match', UserWarning) elif not res: raise SyntaxError("Vectors in vector product must have same size: " "{} vs. {}".format(arr1.shape[0], arr2.shape[0])) @@ -4401,11 +4406,13 @@ def _datatype_converter(sdfg: SDFG, state: SDFGState, arg: UfuncInput, dtype: dt # Set tasklet parameters impl = { - 'name': "_convert_to_{}_".format(dtype.to_string()), + 'name': + "_convert_to_{}_".format(dtype.to_string()), 'inputs': ['__inp'], 'outputs': ['__out'], - 'code': "__out = {}(__inp)".format(f"dace.{dtype.to_string()}" if dtype not in (dace.bool, dace.bool_) - else dtype.to_string()) + 'code': + "__out = {}(__inp)".format(f"dace.{dtype.to_string()}" if dtype not in (dace.bool, + dace.bool_) else dtype.to_string()) } if dtype in (dace.bool, dace.bool_): impl['code'] = "__out = dace.bool_(__inp)" @@ -4733,107 +4740,3 @@ def _op(visitor: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, op1: StringLite for op, method in _boolop_to_method.items(): _makeboolop(op, method) - - -# ScheduleTree-related replacements ################################################################################### - - -@oprepo.replaces('dace.tree.tasklet') -def _tasklet(pv: 'ProgramVisitor', - sdfg: SDFG, - state: 
SDFGState, - label: StringLiteral, - inputs: Dict[StringLiteral, str], - inputs_wcr: Dict[StringLiteral, Union[None, Callable[[Any, Any], Any]]], - outputs: Dict[StringLiteral, str], - outputs_wcr: Dict[StringLiteral, Union[None, Callable[[Any, Any], Any]]], - code: StringLiteral, - language: dtypes.Language): - - # Extract strings from StringLiterals - label = label.value - inputs = {k.value: v for k, v in inputs.items()} - inputs_wcr = {k.value: v for k, v in inputs_wcr.items()} - outputs = {k.value: v for k, v in outputs.items()} - outputs_wcr = {k.value: v for k, v in outputs_wcr.items()} - code = code.value - - # Create Tasklet - tasklet = state.add_tasklet(label, inputs.keys(), outputs.keys(), code, language) - for conn, name in inputs.items(): - access = state.add_access(name) - memlet = Memlet.from_array(name, sdfg.arrays[name]) - memlet.wcr = inputs_wcr[conn] - state.add_edge(access, None, tasklet, conn, memlet) - for conn, name in outputs.items(): - access = state.add_access(name) - memlet = Memlet.from_array(name, sdfg.arrays[name]) - memlet.wcr = outputs_wcr[conn] - state.add_edge(tasklet, conn, access, None, memlet) - - # Handle scope output - for out in outputs.values(): - for (outer_var, outer_rng, _), (var, rng) in pv.accesses.items(): - if out == var: - if not (outer_var, outer_rng, 'w') in pv.accesses: - pv.accesses[(outer_var, outer_rng, 'w')] = (out, rng) - pv.outputs[out] = (state, Memlet(data=outer_var, subset=outer_rng), set()) - break - - -@oprepo.replaces('dace.tree.library') -def _library(pv: 'ProgramVisitor', - sdfg: SDFG, - state: SDFGState, - ltype: type, - label: StringLiteral, - inputs: Union[Dict[StringLiteral, str], Set[str]], - outputs: Union[Dict[StringLiteral, str], Set[str]], - **kwargs): - - # Extract strings from StringLiterals - label = label.value - if isinstance(inputs, dict): - inputs = {k.value: v for k, v in inputs.items()} - else: - inputs = {i: v for i, v in enumerate(inputs)} - if isinstance(outputs, dict): - outputs 
= {k.value: v for k, v in outputs.items()} - else: - outputs = {i: v for i, v in enumerate(outputs)} - - # Create LibraryNode - tasklet = ltype(label, **kwargs) - state.add_node(tasklet) - for k, name in inputs.items(): - access = state.add_access(name) - conn = k if isinstance(k, str) else None - state.add_edge(access, None, tasklet, conn, Memlet.from_array(name, sdfg.arrays[name])) - for k, name in outputs.items(): - access = state.add_access(name) - memlet = Memlet.from_array(name, sdfg.arrays[name]) - conn = k if isinstance(k, str) else None - state.add_edge(tasklet, conn, access, None, memlet) - - # Handle scope output - for out in outputs.values(): - for (outer_var, outer_rng, _), (var, rng) in pv.accesses.items(): - if out == var: - if not (outer_var, outer_rng, 'w') in pv.accesses: - pv.accesses[(outer_var, outer_rng, 'w')] = (out, rng) - pv.outputs[out] = (state, Memlet(data=outer_var, subset=outer_rng), set()) - break - - -@oprepo.replaces('dace.tree.copy') -def _copy(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, src: str, dst: str, wcr: str = None): - src_access = state.add_access(src) - dst_access = state.add_access(dst) - state.add_nedge(src_access, dst_access, Memlet.from_array(dst, sdfg.arrays[dst], wcr=None)) - - for (outer_var, outer_rng, _), (var, rng) in pv.accesses.items(): - if dst == var: - if not (outer_var, outer_rng, 'w') in pv.accesses: - pv.accesses[(outer_var, outer_rng, 'w')] = (dst, rng) - pv.outputs[dst] = (state, Memlet(data=outer_var, subset=outer_rng), set()) - break diff --git a/dace/sdfg/analysis/schedule_tree/passes.py b/dace/sdfg/analysis/schedule_tree/passes.py index 81cedffea2..52a58adc32 100644 --- a/dace/sdfg/analysis/schedule_tree/passes.py +++ b/dace/sdfg/analysis/schedule_tree/passes.py @@ -3,9 +3,7 @@ Assortment of passes for schedule trees. 
""" -from dace import data as dt, Memlet, subsets as sbs, symbolic as sym from dace.sdfg.analysis.schedule_tree import treenodes as tn -from dataclasses import dataclass from typing import Set @@ -63,142 +61,3 @@ def visit(self, node: tn.ScheduleTreeNode): return self.generic_visit(node) return RemoveEmptyScopes().visit(stree) - - -def wcr_to_reduce(stree: tn.ScheduleTreeScope): - """ - Converts WCR assignments to reductions. - - :param stree: The schedule tree to remove WCR assignments from. - """ - - class WCRToReduce(tn.ScheduleNodeTransformer): - - def visit(self, node: tn.ScheduleTreeNode): - - if isinstance(node, tn.TaskletNode): - - wcr_found = False - for _, memlet in node.out_memlets.items(): - if memlet.wcr: - wcr_found = True - break - - if wcr_found: - - loop_found = False - rng = None - idx = None - parent = node.parent - while parent: - if isinstance(parent, (tn.MapScope, tn.ForScope)): - loop_found = True - rng = parent.node.map.range - break - parent = parent.parent - - if loop_found: - - for conn, memlet in node.out_memlets.items(): - if memlet.wcr: - - scope = node.parent - while memlet.data not in scope.containers: - scope = scope.parent - desc = scope.containers[memlet.data] - - shape = rng.size() + list(desc.shape) if not isinstance(desc, dt.Scalar) else rng.size() - parent.containers[f'{memlet.data}_arr'] = dt.Array(desc.dtype, shape, transient=True) - - indices = [(sym.pystr_to_symbolic(s), sym.pystr_to_symbolic(s), 1) for s in parent.node.map.params] - if not isinstance(desc, dt.Scalar): - indices.extend(memlet.subset.ranges) - memlet.subset = sbs.Range(indices) - - from dace.libraries.standard import Reduce - rednode = Reduce(memlet.wcr) - libcall = tn.LibraryCall(rednode, {Memlet.from_array(f'{memlet.data}_arr', parent.containers[f'{memlet.data}_arr'])}, {Memlet.from_array(memlet.data, desc)}) - - memlet.data = f'{memlet.data}_arr' - memlet.wcr = None - - parent.children.append(libcall) - - - return self.generic_visit(node) - - return 
WCRToReduce().visit(stree) - - -def canonicalize_if(tree: tn.ScheduleTreeScope): - """ - Canonicalizes sequences of if-elif-else scopes to sequences of if scopes. - """ - - from dace.sdfg.nodes import CodeBlock - - class CanonicalizeIf(tn.ScheduleNodeTransformer): - - def visit(self, node: tn.ScheduleTreeNode): - - if isinstance(node, (tn.ElifScope, tn.ElseScope)): - parent = node.parent - assert node in parent.children - node_idx = parent.children.index(node) - - conditions = [] - for curr_node in reversed(parent.children[:node_idx]): - conditions.append(curr_node.condition) - if isinstance(curr_node, tn.IfScope): - break - condition = f"not ({' or '.join([f'({c.as_string})' for c in conditions])})" - if isinstance(node, tn.ElifScope): - condition = f"{condition} and {node.condition.as_string}" - new_node = tn.IfScope(node.sdfg, node.top_level, node.children, CodeBlock(condition)) - new_node.parent = parent - else: - new_node = node - - return self.generic_visit(new_node) - - return CanonicalizeIf().visit(tree) - - -def fission_scopes(node: tn.ScheduleTreeScope): - - from dace.sdfg.analysis.schedule_tree.transformations import loop_fission, if_fission - - @dataclass - class FissionScopes(tn.ScheduleNodeTransformer): - - tree: tn.ScheduleTreeScope - - def visit_IfScope(self, node: tn.IfScope): - return if_fission(node, assume_canonical=True, distribute=True) - - def visit_ForScope(self, node: tn.ForScope): - return loop_fission(node, self.tree) - - def visit_MapScope(self, node: tn.MapScope): - return loop_fission(node, self.tree) - - def visit(self, node: tn.ScheduleTreeNode): - node = self.generic_visit(node) - if isinstance(node, (tn.IfScope, tn.ForScope, tn.MapScope)): - return super().visit(node) - return node - - return FissionScopes(node).visit(node) - - -def validate(node: tn.ScheduleTreeNode) -> bool: - - if isinstance(node, tn.ScheduleTreeScope): - if any(child.parent is not node for child in node.children): - return False - if all(validate(child) for 
child in node.children): - return True - else: - return False - - return True diff --git a/dace/sdfg/analysis/schedule_tree/replacements.py b/dace/sdfg/analysis/schedule_tree/replacements.py deleted file mode 100644 index 1df8fb4e6a..0000000000 --- a/dace/sdfg/analysis/schedule_tree/replacements.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. -from dace import SDFG, SDFGState -from dace.frontend.common import op_repository as oprepo -from typing import Tuple - - -@oprepo.replaces('dace.tree.library') -def _library(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, ltype: str, label: str, inputs: Tuple[str], outputs: Tuple[str]): - print(ltype) - print(label) - print(inputs) - print(outputs) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 651bdce626..72df0a0472 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -8,7 +8,7 @@ from dace.sdfg.state import SDFGState from dace.sdfg import utils as sdutil, graph as gr from dace.frontend.python.astutils import negate_expr -from dace.sdfg.analysis.schedule_tree import treenodes as tn, passes as stpasses, utils as tutils +from dace.sdfg.analysis.schedule_tree import treenodes as tn, passes as stpasses from dace.transformation.helpers import unsqueeze_memlet from dace.properties import CodeBlock from dace.memlet import Memlet @@ -47,7 +47,7 @@ def dealias_sdfg(sdfg: SDFG): else: inv_replacements[parent_name] = [name] break - + if to_unsqueeze: for parent_name in to_unsqueeze: parent_arr = parent_sdfg.arrays[parent_name] @@ -87,11 +87,16 @@ def dealias_sdfg(sdfg: SDFG): else: e.data.other_subset = subsets.Range.from_array(parent_arr) - if replacements: nsdfg.replace_dict(replacements) - parent_node.in_connectors = {replacements[c] if c in replacements else c: t for c, t in parent_node.in_connectors.items()} - 
parent_node.out_connectors = {replacements[c] if c in replacements else c: t for c, t in parent_node.out_connectors.items()} + parent_node.in_connectors = { + replacements[c] if c in replacements else c: t + for c, t in parent_node.in_connectors.items() + } + parent_node.out_connectors = { + replacements[c] if c in replacements else c: t + for c, t in parent_node.out_connectors.items() + } for e in parent_state.all_edges(parent_node): if e.src_conn in replacements: e._src_conn = replacements[e.src_conn] @@ -99,22 +104,6 @@ def dealias_sdfg(sdfg: SDFG): e._dst_conn = replacements[e.dst_conn] -def populate_containers(scope: tn.ScheduleTreeScope, defined_arrays: Set[str] = None): - defined_arrays = defined_arrays or set() - if scope.top_level: - scope.containers = {name: copy.deepcopy(desc) for name, desc in scope.sdfg.arrays.items() if not desc.transient} - scope.symbols = dict() - for sdfg in scope.sdfg.all_sdfgs_recursive(): - scope.symbols.update(sdfg.symbols) - defined_arrays = set(scope.containers.keys()) - _, defined_arrays = scope.define_arrays(0, defined_arrays) - for child in scope.children: - child.parent = scope - if isinstance(child, tn.ScheduleTreeScope): - # _, defined_arrays = child.define_arrays(0, defined_arrays) - populate_containers(child, defined_arrays) - - def normalize_memlet(sdfg: SDFG, state: SDFGState, original: gr.MultiConnectorEdge[Memlet], data: str) -> Memlet: """ Normalizes a memlet to a given data descriptor. 
@@ -166,7 +155,6 @@ def replace_memlets(sdfg: SDFG, array_mapping: Dict[str, Memlet]): if s in array_mapping: repl_dict[s] = str(array_mapping[s]) e.data.replace_dict(repl_dict) - def remove_name_collisions(sdfg: SDFG): @@ -326,7 +314,6 @@ def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[ target_name = innermost_node.data new_memlet = normalize_memlet(sdfg, state, e, outermost_node.data) result[e] = tn.CopyNode(target=target_name, memlet=new_memlet) - # result[e] = tn.CopyNode(sdfg=sdfg, target=target_name, memlet=new_memlet) return result @@ -363,7 +350,10 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: # Create scope node and add to stack scopes.append(result) subnodes = [] - result.append(NODE_TO_SCOPE_TYPE[type(node)](node=node, sdfg=state.parent, top_level=False, children=subnodes)) + result.append(NODE_TO_SCOPE_TYPE[type(node)](node=node, + sdfg=state.parent, + top_level=False, + children=subnodes)) result = subnodes elif isinstance(node, dace.nodes.ExitNode): result = scopes.pop() @@ -497,7 +487,10 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche if e not in parent.assignments_to_ignore: for aname, aval in e.data.assignments.items(): - edge_body.append(tn.AssignNode(name=aname, value=CodeBlock(aval), edge=InterstateEdge(assignments={aname: aval}))) + edge_body.append( + tn.AssignNode(name=aname, + value=CodeBlock(aval), + edge=InterstateEdge(assignments={aname: aval}))) if not parent.sequential: if e not in parent.gotos_to_ignore: @@ -512,12 +505,18 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche if sdfg.out_degree(node.state) == 1 and parent.sequential: # Conditional state in sequential block! 
Add "if not condition goto exit" result.append( - tn.StateIfScope(sdfg=sdfg, top_level=False, condition=CodeBlock(negate_expr(e.data.condition)), + tn.StateIfScope(sdfg=sdfg, + top_level=False, + condition=CodeBlock(negate_expr(e.data.condition)), children=[tn.GotoNode(target=None)])) result.extend(edge_body) else: # Add "if condition" with the body above - result.append(tn.StateIfScope(sdfg=sdfg, top_level=False, condition=e.data.condition, children=edge_body)) + result.append( + tn.StateIfScope(sdfg=sdfg, + top_level=False, + condition=e.data.condition, + children=edge_body)) else: result.extend(edge_body) @@ -529,7 +528,8 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche result.append(tn.ElseScope(sdfg=sdfg, top_level=False, children=totree(node.orelse))) elif isinstance(node, cf.IfElseChain): # Add "if" for the first condition, "elif"s for the rest - result.append(tn.IfScope(sdfg=sdfg, top_level=False, condition=node.body[0][0], children=totree(node.body[0][1]))) + result.append( + tn.IfScope(sdfg=sdfg, top_level=False, condition=node.body[0][0], children=totree(node.body[0][1]))) for cond, body in node.body[1:]: result.append(tn.ElifScope(sdfg=sdfg, top_level=False, condition=cond, children=totree(body))) # "else goto exit" @@ -550,44 +550,12 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche # Recursive traversal of the control flow tree result = tn.ScheduleTreeScope(sdfg=sdfg, top_level=True, children=totree(cfg)) - if toplevel: - populate_containers(result) - # Clean up tree stpasses.remove_unused_and_duplicate_labels(result) return result -def as_sdfg(tree: tn.ScheduleTreeScope) -> SDFG: - """ - Converts a ScheduleTree to its SDFG representation. - - :param tree: The ScheduleTree - :return: The ScheduleTree's SDFG representation - """ - - # Write tree as DaCe Python code. - code, _ = tree.as_python() - - # Save DaCe Python code to temporary file. 
- import tempfile - tmp = tempfile.NamedTemporaryFile(suffix='.py', delete=False) - tmp.write(b'import dace\n') - tmp.write(b'import numpy\n') - tmp.write(bytes(code, encoding='utf-8')) - tmp.close() - - # Load DaCe Python program from temporary file. - import importlib.util - spec = importlib.util.spec_from_file_location(tmp.name.split('/')[-1][:-3], tmp.name) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - prog = eval(f"mod.{tree.sdfg.label}") - - return prog.to_sdfg() - - if __name__ == '__main__': s = time.time() sdfg = SDFG.from_file(sys.argv[1]) diff --git a/dace/sdfg/analysis/schedule_tree/transformations.py b/dace/sdfg/analysis/schedule_tree/transformations.py deleted file mode 100644 index fd7bfad842..0000000000 --- a/dace/sdfg/analysis/schedule_tree/transformations.py +++ /dev/null @@ -1,357 +0,0 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. -import copy -from dace import data as dt, dtypes, Memlet, SDFG, subsets -from dace.sdfg import nodes as dnodes -from dace.sdfg.analysis.schedule_tree import treenodes as tnodes -from dace.sdfg.analysis.schedule_tree import utils as tutils -import re -from typing import Dict, List, Set, Tuple, Union - - -_dataflow_nodes = (tnodes.ViewNode, tnodes.RefSetNode, tnodes.CopyNode, tnodes.DynScopeCopyNode, tnodes.TaskletNode, tnodes.LibraryCall) - - -def _get_loop_size(loop: Union[tnodes.MapScope, tnodes.ForScope]) -> Tuple[str, list, list]: - # Generate loop-related indices, sizes and, strides - if isinstance(loop, tnodes.MapScope): - map = loop.node.map - index = ", ".join(f"{p}/{r[2]}-{r[0]}" if r[2] != 1 else f"{p}-{r[0]}" for p, r in zip(map.params, map.range)) - size = map.range.size() - else: - itervar = loop.header.itervar - start = loop.header.init - # NOTE: Condition expression may be inside parentheses - par = re.search(f"^\s*(\(?)\s*{itervar}", loop.header.condition.as_string).group(1) == '(' - if par: - stop_match = 
re.search(f"\(\s*{itervar}\s*([<>=]+)\s*(.+)\s*\)", loop.header.condition.as_string) - else: - stop_match = re.search(f"{itervar}\s*([<>=]+)\s*(.+)", loop.header.condition.as_string) - stop_op = stop_match.group(1) - assert stop_op in ("<", "<=", ">", ">=") - stop = stop_match.group(2) - # NOTE: Update expression may be inside parentheses - par = re.search(f"^\s*(\(?)\s*{itervar}", loop.header.update).group(1) == '(' - if par: - step_match = re.search(f"\(\s*{itervar}\s*([(*+-/%)]+)\s*([a-zA-Z0-9_]+)\s*\)", loop.header.update) - else: - step_match = re.search(f"{itervar}\s*([(*+-/%)]+)\s*([a-zA-Z0-9_]+)", loop.header.update) - try: - step_op = step_match.group(1) - step = step_match.group(2) - if step_op == '+': - step = int(step) - index = f"{itervar}/{step}-{start}" if step != 1 else f"{itervar}-{start}" - else: - raise ValueError - except (AttributeError, ValueError): - step = 1 if '<' in stop_op else -1 - index = itervar - if "=" in stop_op: - stop = f"{stop} + ({step})" - size = subsets.Range.from_string(f"{start}:{stop}:{step}").size() - - strides = [1] * len(size) - for i in range(len(size) - 2, -1, -1): - strides[i] = strides[i+1] * size[i+1] - - return index, size, strides - - -def _update_memlets(data: Dict[str, dt.Data], memlets: Dict[str, Memlet], index: str, replace: Dict[str, bool]): - for conn, memlet in memlets.items(): - if memlet.data in data: - subset = index if replace[memlet.data] else f"{index}, {memlet.subset}" - memlets[conn] = Memlet(data=memlet.data, subset=subset) - - -def _augment_data(data: Set[str], loop: Union[tnodes.MapScope, tnodes.ForScope], tree: tnodes.ScheduleTreeNode): - - # # Generate loop-related indices, sizes and, strides - # if isinstance(loop, tnodes.MapScope): - # map = loop.node.map - # index = ", ".join(f"{p}/{r[2]}-{r[0]}" if r[2] != 1 else f"{p}-{r[0]}" for p, r in zip(map.params, map.range)) - # size = map.range.size() - # else: - # itervar = loop.header.itervar - # start = loop.header.init - # # NOTE: Condition 
expression may be inside parentheses - # par = re.search(f"^\s*(\(?)\s*{itervar}", loop.header.condition.as_string).group(1) == '(' - # if par: - # stop_match = re.search(f"\(\s*{itervar}\s*([<>=]+)\s*(.+)\s*\)", loop.header.condition.as_string) - # else: - # stop_match = re.search(f"{itervar}\s*([<>=]+)\s*(.+)", loop.header.condition.as_string) - # stop_op = stop_match.group(1) - # assert stop_op in ("<", "<=", ">", ">=") - # stop = stop_match.group(2) - # # NOTE: Update expression may be inside parentheses - # par = re.search(f"^\s*(\(?)\s*{itervar}", loop.header.update).group(1) == '(' - # if par: - # step_match = re.search(f"\(\s*{itervar}\s*([(*+-/%)]+)\s*([a-zA-Z0-9_]+)\s*\)", loop.header.update) - # else: - # step_match = re.search(f"{itervar}\s*([(*+-/%)]+)\s*([a-zA-Z0-9_]+)", loop.header.update) - # try: - # step_op = step_match.group(1) - # step = step_match.group(2) - # if step_op == '+': - # step = int(step) - # index = f"{itervar}/{step}-{start}" if step != 1 else f"{itervar}-{start}" - # else: - # raise ValueError - # except (AttributeError, ValueError): - # step = 1 if '<' in stop_op else -1 - # index = itervar - # if "=" in stop_op: - # stop = f"{stop} + ({step})" - # size = subsets.Range.from_string(f"{start}:{stop}:{step}").size() - - # strides = [1] * len(size) - # for i in range(len(size) - 2, -1, -1): - # strides[i] = strides[i+1] * size[i+1] - - index, size, strides = _get_loop_size(loop) - - # Augment data descriptors - replace = dict() - for name in data: - desc = loop.containers[name] - if isinstance(desc, dt.Scalar): - - desc = dt.Array(desc.dtype, size, True, storage=desc.storage) - replace[name] = True - else: - mult = desc.shape[0] - desc.shape = (*size, *desc.shape) - new_strides = [s * mult for s in strides] - desc.strides = (*new_strides, *desc.strides) - replace[name] = False - del loop.containers[name] - loop.parent.containers[name] = desc - - # Update memlets - frontier = list(tree.children) - while frontier: - node = 
frontier.pop() - if isinstance(node, _dataflow_nodes): - try: - _update_memlets(data, node.in_memlets, index, replace) - _update_memlets(data, node.out_memlets, index, replace) - except AttributeError: - if node.memlet.data in data: - subset = index if replace[node.memlet.data] else f"{index}, {node.memlet.subset}" - node.memlet = Memlet(data=node.memlet.data, subset=subset) - if hasattr(node, 'children'): - frontier.extend(node.children) - - - -def loop_fission(loop: Union[tnodes.MapScope, tnodes.ForScope], tree: tnodes.ScheduleTreeNode) -> List[Union[tnodes.MapScope, tnodes.ForScope]]: - """ - Applies the LoopFission transformation to the input MapScope or ForScope. - - :param loop: The MapScope or ForScope. - :param tree: The ScheduleTree. - :return: True if the transformation applies successfully, otherwise False. - """ - - sdfg = loop.sdfg - - #################################### - # Check if LoopFission can be applied - - # Basic check: cannot fission an empty MapScope/ForScope or one that has a single child. 
- partition = tutils.partition_scope_body(loop) - if len(partition) < 2: - return [loop] - - index, _, _ = _get_loop_size(loop) - - data_to_augment = set() - assignments = dict() - frontier = list(partition) - while len(frontier) > 0: - scope = frontier.pop() - if isinstance(scope, _dataflow_nodes): - try: - for _, memlet in scope.out_memlets.items(): - if memlet.data in loop.containers: - data_to_augment.add(memlet.data) - except AttributeError: - if scope.target in loop.containers: - data_to_augment.add(scope.target) - elif isinstance(scope, tnodes.AssignNode): - symbol = tree.symbols[scope.name] - loop.containers[f"{scope.name}_arr"] = dt.Scalar(symbol.dtype, transient=True) - data_to_augment.add(f"{scope.name}_arr") - repl_dict = {scope.name: '__out'} - out_memlets = {'__out': Memlet(data=f"{scope.name}_arr", subset='0')} - in_memlets = dict() - for i, memlet in enumerate(scope.edge.get_read_memlets(scope.parent.sdfg.arrays)): - repl_dict[str(memlet)] = f'__in{i}' - in_memlets[f'__in{i}'] = memlet - scope.edge.replace_dict(repl_dict) - tasklet = dnodes.Tasklet('some_label', in_memlets.keys(), {'__out'}, - f"__out = {scope.edge.assignments['__out']}") - tnode = tnodes.TaskletNode(tasklet, in_memlets, out_memlets) - tnode.parent = loop - idx = loop.children.index(scope) - loop.children[idx] = tnode - idx = partition.index(scope) - partition[idx] = tnode - edge = copy.deepcopy(scope.edge) - edge.assignments['__out'] = f"{scope.name}_arr[{index}]" - assignments[scope.name] = (dnodes.CodeBlock(f"{scope.name}[{index}]"), edge, idx) - elif hasattr(scope, 'children'): - frontier.extend(scope.children) - _augment_data(data_to_augment, loop, tree) - print(data_to_augment) - - - new_scopes = [] - # while partition: - # child = partition.pop(0) - for i, child in enumerate(partition): - if not isinstance(child, list): - child = [child] - - for c in list(child): - idx = child.index(c) - # Reverse access? 
- for name, (value, edge, index) in assignments.items(): - if index == i: - continue - if c.is_data_used(name, True): - child.insert(idx, tnodes.AssignNode(f"{name}", copy.deepcopy(value), copy.deepcopy(edge))) - - if isinstance(loop, tnodes.MapScope): - scope = tnodes.MapScope(sdfg, False, child, copy.deepcopy(loop.node)) - else: - scope = tnodes.ForScope(sdfg, False, child, copy.copy(loop.header)) - for child in scope.children: - child.parent = scope - if isinstance(child, tnodes.ScheduleTreeScope): - scope.containers.update(child.containers) - scope.parent = loop.parent - new_scopes.append(scope) - - return new_scopes - - -def map_fission(map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode) -> bool: - """ - Applies the MapFission transformation to the input MapScope. - - :param map_scope: The MapScope. - :param tree: The ScheduleTree. - :return: True if the transformation applies successfully, otherwise False. - """ - - sdfg = map_scope.sdfg - - #################################### - # Check if MapFission can be applied - - # Basic check: cannot fission an empty MapScope or one that has a single dataflow child. - num_children = len(map_scope.children) - if num_children == 0 or (num_children == 1 and isinstance(map_scope.children[0], _dataflow_nodes)): - return False - - # State-scope check: if the body consists of a single state-scope, certain conditions apply. 
- partition = tutils.partition_scope_body(map_scope) - if len(partition) == 1: - - child = partition[0] - conditions = [] - if isinstance(child, list): - # If-Elif-Else-Scope - for c in child: - if isinstance(c, (tnodes.IfScope, tnodes.ElifScope)): - conditions.append(c.condition) - elif isinstance(child, tnodes.ForScope): - conditions.append(child.header.condition) - elif isinstance(child, tnodes.WhileScope): - conditions.append(child.header.test) - - for cond in conditions: - map = map_scope.node.map - if any(p in cond.get_free_symbols() for p in map.params): - return False - - # TODO: How to run the check below in the ScheduleTree? - # for s in cond.get_free_symbols(): - # for e in graph.edges_by_connector(self.nested_sdfg, s): - # if any(p in e.data.free_symbols for p in map.params): - # return False - - # data_to_augment = dict() - data_to_augment = set() - frontier = list(partition) - while len(frontier) > 0: - scope = frontier.pop() - if isinstance(scope, _dataflow_nodes): - try: - for _, memlet in scope.out_memlets.items(): - if memlet.data in map_scope.containers: - data_to_augment.add(memlet.data) - # if scope.sdfg.arrays[memlet.data].transient: - # data_to_augment[memlet.data] = scope.sdfg - except AttributeError: - if scope.target in map_scope.containers: - data_to_augment.add(scope.target) - # if scope.target in scope.sdfg.arrays and scope.sdfg.arrays[scope.target].transient: - # data_to_augment[scope.target] = scope.sdfg - if hasattr(scope, 'children'): - frontier.extend(scope.children) - _augment_data(data_to_augment, map_scope, tree) - - parent_scope = map_scope.parent - idx = parent_scope.children.index(map_scope) - parent_scope.children.pop(idx) - while len(partition) > 0: - child_scope = partition.pop() - if not isinstance(child_scope, list): - child_scope = [child_scope] - scope = tnodes.MapScope(sdfg, False, child_scope, copy.deepcopy(map_scope.node)) - scope.parent = parent_scope - parent_scope.children.insert(idx, scope) - - return True - - 
-def if_fission(if_scope: tnodes.IfScope, assume_canonical: bool = False, distribute: bool = False) -> List[tnodes.IfScope]: - - from dace.sdfg.nodes import CodeBlock - - # Check transformation conditions - # Scope must not have subsequent elif or else scopes - if not assume_canonical: - idx = if_scope.parent.children.index(if_scope) - if len(if_scope.parent.children) > idx + 1 and isinstance(if_scope.parent.children[idx+1], - (tnodes.ElifScope, tnodes.ElseScope)): - return [if_scope] - if len(if_scope.children) < 2 and not (isinstance(if_scope.children[0], tnodes.IfScope) and distribute): - return [if_scope] - - new_scopes = [] - partition = tutils.partition_scope_body(if_scope) - while partition: - child = partition.pop(0) - if isinstance(child, list) and len(child) == 1 and isinstance(child[0], tnodes.IfScope) and distribute: - scope = tnodes.IfScope(if_scope.sdfg, False, child[0].children, CodeBlock(f"{if_scope.condition.as_string} and {child[0].condition.as_string}")) - scope.containers.update(child[0].containers) - else: - if not isinstance(child, list): - child = [child] - scope = tnodes.IfScope(if_scope.sdfg, False, child, copy.deepcopy(if_scope.condition)) - for child in scope.children: - child.parent = scope - if isinstance(child, tnodes.ScheduleTreeScope): - scope.containers.update(child.containers) - scope.parent = if_scope.parent - new_scopes.append(scope) - - return new_scopes - - -def wcr_to_reduce(map_scope: tnodes.MapScope, tree: tnodes.ScheduleTreeNode) -> bool: - - pass \ No newline at end of file diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 877093317d..5d200efc11 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -28,34 +28,6 @@ class ScheduleTreeNode: def as_string(self, indent: int = 0): return indent * INDENTATION + 'UNSUPPORTED' - def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, 
Set[str]]: - string, defined_arrays = self.define_arrays(indent, defined_arrays) - return string + indent * INDENTATION + 'pass', defined_arrays - - def define_arrays(self, indent: int, defined_arrays: Set[str]) -> Tuple[str, Set[str]]: - return '', defined_arrays - # defined_arrays = defined_arrays or set() - # string = '' - # undefined_arrays = {name: desc for name, desc in self.sdfg.arrays.items() if not name in defined_arrays and desc.transient} - # if hasattr(self, 'children'): - # times_used = {name: 0 for name in undefined_arrays} - # for child in self.children: - # for name in undefined_arrays: - # if child.is_data_used(name): - # times_used[name] += 1 - # undefined_arrays = {name: desc for name, desc in undefined_arrays.items() if times_used[name] > 1} - # for name, desc in undefined_arrays.items(): - # string += indent * INDENTATION + f"{name} = numpy.ndarray({desc.shape}, {TYPECLASS_TO_STRING[desc.dtype].replace('::', '.')})\n" - # defined_arrays |= undefined_arrays.keys() - # return string, defined_arrays - - def is_data_used(self, name: str, include_symbols: bool = False) -> bool: - pass - # for child in self.children: - # if child.is_data_used(name): - # return True - # return False - def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: """ Traverse tree nodes in a pre-order manner. 
@@ -71,7 +43,6 @@ class ScheduleTreeScope(ScheduleTreeNode): containers: Optional[Dict[str, data.Data]] = field(default_factory=dict, init=False) symbols: Optional[Dict[str, symbol]] = field(default_factory=dict, init=False) - # def __init__(self, sdfg: Optional[SDFG] = None, top_level: Optional[bool] = False, children: Optional[List['ScheduleTreeNode']] = None): def __init__(self, sdfg: Optional[SDFG] = None, top_level: Optional[bool] = False, @@ -82,23 +53,6 @@ def __init__(self, if self.children: for child in children: child.parent = self - # self.__post_init__() - # for child in children: - # child.parent = self - # _, defined_arrays = self.define_arrays(0, set()) - # self.containers = {name: copy.deepcopy(sdfg.arrays[name]) for name in defined_arrays} - # if top_level: - # self.containers.update({name: copy.deepcopy(desc) for name, desc in sdfg.arrays.items() if not desc.transient}) - # # self.containers = {name: copy.deepcopy(container) for name, container in sdfg.arrays.items()} - - # def __post_init__(self): - # for child in self.children: - # child.parent = self - # _, defined_arrays = self.define_arrays(0, set()) - # self.containers = {name: copy.deepcopy(self.sdfg.arrays[name]) for name in defined_arrays} - # if self.top_level: - # self.containers.update({name: copy.deepcopy(desc) for name, desc in self.sdfg.arrays.items() if not desc.transient}) - # # self.containers = {name: copy.deepcopy(container) for name, container in sdfg.arrays.items()} def as_string(self, indent: int = 0): return '\n'.join([child.as_string(indent + 1) for child in self.children]) @@ -111,74 +65,6 @@ def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: for child in self.children: yield from child.preorder_traversal() - def as_python(self, - indent: int = 0, - defined_arrays: Set[str] = None, - def_offset: int = 1, - sep_defs: bool = False) -> Tuple[str, Set[str]]: - if self.top_level: - header = '' - for s in self.sdfg.free_symbols: - header += f"{s} = dace.symbol('{s}', 
{TYPECLASS_TO_STRING[self.sdfg.symbols[s]].replace('::', '.')})\n" - header += f""" -@dace.program -def {self.sdfg.label}({self.sdfg.python_signature()}): -""" - # defined_arrays = set([name for name, desc in self.sdfg.arrays.items() if not desc.transient]) - defined_arrays = set([name for name, desc in self.containers.items() if not desc.transient]) - else: - header = '' - defined_arrays = defined_arrays or set() - cindent = indent + def_offset - # string, defined_arrays = self.define_arrays(indent + 1, defined_arrays) - definitions = '' - body = '' - undefined_arrays = {name: desc for name, desc in self.containers.items() if name not in defined_arrays} - for name, desc in undefined_arrays.items(): - if isinstance(desc, data.Scalar): - definitions += cindent * INDENTATION + f"{name} = numpy.{desc.dtype.as_numpy_dtype()}(0)\n" - else: - definitions += cindent * INDENTATION + f"{name} = numpy.ndarray({desc.shape}, {TYPECLASS_TO_STRING[desc.dtype].replace('::', '.')})\n" - defined_arrays |= undefined_arrays.keys() - for child in self.children: - substring, defined_arrays = child.as_python(indent + 1, defined_arrays) - body += substring - if body[-1] != '\n': - body += '\n' - if sep_defs: - return definitions, body, defined_arrays - else: - return header + definitions + body, defined_arrays - - def define_arrays(self, indent: int, defined_arrays: Set[str]) -> Tuple[str, Set[str]]: - defined_arrays = defined_arrays or set() - string = '' - undefined_arrays = {} - for sdfg in self.sdfg.all_sdfgs_recursive(): - undefined_arrays.update( - {name: desc - for name, desc in sdfg.arrays.items() if not name in defined_arrays and desc.transient}) - # undefined_arrays = {name: desc for name, desc in self.sdfg.arrays.items() if not name in defined_arrays and desc.transient} - times_used = {name: 0 for name in undefined_arrays} - for child in self.children: - for name in undefined_arrays: - if child.is_data_used(name): - times_used[name] += 1 - undefined_arrays = {name: desc for 
name, desc in undefined_arrays.items() if times_used[name] > 1} - if not self.containers: - self.containers = {} - for name, desc in undefined_arrays.items(): - string += indent * INDENTATION + f"{name} = numpy.ndarray({desc.shape}, {TYPECLASS_TO_STRING[desc.dtype].replace('::', '.')})\n" - self.containers[name] = copy.deepcopy(desc) - defined_arrays |= undefined_arrays.keys() - return string, defined_arrays - - def is_data_used(self, name: str) -> bool: - for child in self.children: - if child.is_data_used(name): - return True - return False - # TODO: Get input/output memlets? @@ -204,8 +90,6 @@ def as_string(self, indent: int = 0): result = indent * INDENTATION + 'gblock:\n' return result + super().as_string(indent) - pass - @dataclass class StateLabel(ScheduleTreeNode): @@ -251,15 +135,6 @@ def as_string(self, indent: int = 0): f'{node.itervar} = {node.update}:\n') return result + super().as_string(indent) - def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: - node = self.header - result = indent * INDENTATION + f'{node.itervar} = {node.init}\n' - result += indent * INDENTATION + f'while {node.condition.as_string}:\n' - defs, body, defined_arrays = super().as_python(indent, defined_arrays, def_offset=0, sep_defs=True) - result = defs + result + body - result += (indent + 1) * INDENTATION + f'{node.itervar} = {node.update}\n' - return result, defined_arrays - @dataclass class WhileScope(ControlFlowScope): @@ -297,16 +172,6 @@ def as_string(self, indent: int = 0): result = indent * INDENTATION + f'if {self.condition.as_string}:\n' return result + super().as_string(indent) - def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: - result = indent * INDENTATION + f'if {self.condition.as_string}:\n' - string, defined_arrays = super().as_python(indent, defined_arrays) - return result + string, defined_arrays - - def is_data_used(self, name: str) -> bool: - result = name in 
self.condition.get_free_symbols() - result |= super().is_data_used(name) - return result - @dataclass class StateIfScope(IfScope): @@ -373,12 +238,6 @@ def as_string(self, indent: int = 0): result = indent * INDENTATION + f'map {", ".join(self.node.map.params)} in [{rangestr}]:\n' return result + super().as_string(indent) - def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: - rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) - result = indent * INDENTATION + f'for {", ".join(self.node.map.params)} in dace.map[{rangestr}]:\n' - string, defined_arrays = super().as_python(indent, defined_arrays) - return result + string, defined_arrays - @dataclass class ConsumeScope(DataflowScope): @@ -405,14 +264,6 @@ def as_string(self, indent: int = 0): return result + super().as_string(indent) -def _memlet_to_str(memlet: Memlet) -> str: - assert memlet.other_subset == None - wcr = "" - if memlet.wcr: - wcr = f"({reduce(lambda x, y: x * y, memlet.subset.size())}, {memlet.wcr})" - return f"{memlet.data}{wcr}[{memlet.subset}]" - - @dataclass class TaskletNode(ScheduleTreeNode): node: nodes.Tasklet @@ -424,28 +275,6 @@ def as_string(self, indent: int = 0): out_memlets = ', '.join(f'{v}' for v in self.out_memlets.values()) return indent * INDENTATION + f'{out_memlets} = tasklet({in_memlets})' - def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: - explicit_dataflow = indent * INDENTATION + "with dace.tasklet:\n" - for conn, memlet in self.in_memlets.items(): - explicit_dataflow += (indent + 1) * INDENTATION + f"{conn} << {_memlet_to_str(memlet)}\n" - for conn, memlet in self.out_memlets.items(): - explicit_dataflow += (indent + 1) * INDENTATION + f"{conn} >> {_memlet_to_str(memlet)}\n" - code = self.node.code.as_string.replace('\n', f"\n{(indent + 1) * INDENTATION}") - explicit_dataflow += (indent + 1) * INDENTATION + code - defined_arrays = defined_arrays or set() - 
string, defined_arrays = self.define_arrays(indent, defined_arrays) - return string + explicit_dataflow, defined_arrays - - def is_data_used(self, name: str, include_symbols: bool = False) -> bool: - used_data = set([memlet.data for memlet in self.in_memlets.values()]) - used_data |= set([memlet.data for memlet in self.out_memlets.values()]) - if include_symbols: - for memlet in self.in_memlets.values(): - used_data |= memlet.subset.free_symbols - if memlet.other_subset: - used_data |= memlet.other_subset.free_symbols - return name in used_data - @dataclass class LibraryCall(ScheduleTreeNode): @@ -468,34 +297,6 @@ def as_string(self, indent: int = 0): if v.owner not in {nodes.Node, nodes.CodeNode, nodes.LibraryNode}) return indent * INDENTATION + f'{out_memlets} = library {libname}[{own_properties}]({in_memlets})' - def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: - if isinstance(self.in_memlets, set): - in_memlets = ', '.join(f'{v}' for v in self.in_memlets) - else: - in_memlets = ', '.join(f"'{k}': {v}" for k, v in self.in_memlets.items()) - if isinstance(self.out_memlets, set): - out_memlets = ', '.join(f'{v}' for v in self.out_memlets) - else: - out_memlets = ', '.join(f"'{k}': {v}" for k, v in self.out_memlets.items()) - libname = type(self.node).__module__ + '.' 
+ type(self.node).__qualname__ - # Get the properties of the library node without its superclasses - own_properties = ', '.join(f'{k}={getattr(self.node, k)}' for k, v in self.node.__properties__.items() - if v.owner not in {nodes.Node, nodes.CodeNode, nodes.LibraryNode}) - defined_arrays = defined_arrays or set() - string, defined_arrays = self.define_arrays(indent, defined_arrays) - return string + indent * INDENTATION + f"dace.tree.library(ltype={libname}, label='{self.node.label}', inputs={{{in_memlets}}}, outputs={{{out_memlets}}}, {own_properties})", defined_arrays - - def is_data_used(self, name: str) -> bool: - if isinstance(self.in_memlets, set): - used_data = set([memlet.data for memlet in self.in_memlets]) - else: - used_data = set([memlet.data for memlet in self.in_memlets.values()]) - if isinstance(self.out_memlets, set): - used_data |= set([memlet.data for memlet in self.out_memlets]) - else: - used_data |= set([memlet.data for memlet in self.out_memlets.values()]) - return name in used_data - @dataclass class CopyNode(ScheduleTreeNode): @@ -514,23 +315,6 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target}{offset} = copy {self.memlet.data}[{self.memlet.subset}]{wcr}' - def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: - if self.memlet.other_subset is not None and any(s != 0 for s in self.memlet.other_subset.min_element()): - offset = f'[{self.memlet.other_subset}]' - else: - offset = f'[{self.memlet.subset}]' - if self.memlet.wcr is not None: - wcr = f' with {self.memlet.wcr}' - else: - wcr = '' - - defined_arrays = defined_arrays or set() - string, defined_arrays = self.define_arrays(indent, defined_arrays) - return string + indent * INDENTATION + f'dace.tree.copy(src={self.memlet.data}[{self.memlet.subset}], dst={self.target}{offset}, wcr={self.memlet.wcr})', defined_arrays - - def is_data_used(self, name: str) -> bool: - return name is self.memlet.data or name is 
self.target - @dataclass class DynScopeCopyNode(ScheduleTreeNode): @@ -555,15 +339,6 @@ class ViewNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = view {self.memlet} as {self.view_desc.shape}' - def as_python(self, indent: int = 0, defined_arrays: Set[str] = None) -> Tuple[str, Set[str]]: - defined_arrays = defined_arrays or set() - string, defined_arrays = self.define_arrays(indent, defined_arrays) - return string + indent * INDENTATION + f"{self.target} = {self.memlet}", defined_arrays - - def is_data_used(self, name: str) -> bool: - # NOTE: View data must not be considered used - return name is self.memlet.data - @dataclass class NView(ViewNode): diff --git a/dace/sdfg/analysis/schedule_tree/utils.py b/dace/sdfg/analysis/schedule_tree/utils.py deleted file mode 100644 index 2257c203c8..0000000000 --- a/dace/sdfg/analysis/schedule_tree/utils.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. -from dace.sdfg import SDFG, SDFGState -from dace.sdfg import nodes as snodes -from dace.sdfg.analysis.schedule_tree import treenodes as tnodes -from dace.sdfg.graph import NodeNotFoundError -from typing import List, Tuple, Union - - -def partition_scope_body(scope: tnodes.ScheduleTreeScope) -> List[Union[tnodes.ScheduleTreeNode, List[tnodes.ScheduleTreeNode]]]: - """ - Partitions a scope's body to ScheduleTree nodes, when they define their own sub-scope, and lists of ScheduleTree - nodes that are children to the same sub-scope. For example, IfScopes, ElifScopes, and ElseScopes are generally - children to a general "If-Elif-Else-Scope". - - :param scope: The scope. - :return: A list of (lists of) ScheduleTree nodes. - """ - - num_children = len(scope.children) - partition = [] - i = 0 - while i < num_children: - child = scope.children[i] - if isinstance(child, tnodes.IfScope): - # Start If-Elif-Else-Scope. 
- ifelse = [child] - i += 1 - while i < num_children and isinstance(scope.children[i], (tnodes.ElifScope, tnodes.ElseScope)): - ifelse.append(child) - i += 1 - partition.append(ifelse) - else: - partition.append(child) - i += 1 - return partition diff --git a/tests/schedule_tree/conversion_test.py b/tests/schedule_tree/conversion_test.py deleted file mode 100644 index f394a1fa61..0000000000 --- a/tests/schedule_tree/conversion_test.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. -import dace -import numpy as np -from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree, as_sdfg - - -# TODO: The test fails because of the ambiguity when having access nodes inside a MapScope but on the same SDFG level. -def test_map_with_tasklet_and_library(): - - N = dace.symbol('N') - @dace.program - def map_with_tasklet_and_library(A: dace.float32[N, 5, 5], B: dace.float32[N, 5, 5], cst: dace.int32): - out = np.ndarray((N, 5, 5), dtype=dace.float32) - for i in dace.map[0:N]: - out[i] = cst * (A[i] @ B[i]) - return out - - rng = np.random.default_rng(42) - A = rng.random((10, 5, 5), dtype=np.float32) - B = rng.random((10, 5, 5), dtype=np.float32) - cst = rng.integers(0, 100, dtype=np.int32) - ref = cst * (A @ B) - - val0 = map_with_tasklet_and_library(A, B, cst) - sdfg0 = map_with_tasklet_and_library.to_sdfg() - tree = as_schedule_tree(sdfg0) - pcode, _ = tree.as_python() - print(pcode) - sdfg1 = as_sdfg(tree) - val1 = sdfg1(A=A, B=B, cst=cst, N=A.shape[0]) - - assert np.allclose(val0, ref) - assert np.allclose(val1, ref) - - -def test_azimint_naive(): - - N, npt = (dace.symbol(s) for s in ('N', 'npt')) - @dace.program - def dace_azimint_naive(data: dace.float64[N], radius: dace.float64[N]): - rmax = np.amax(radius) - res = np.zeros((npt, ), dtype=np.float64) - for i in range(npt): - # for i in dace.map[0:npt]: - r1 = rmax * i / npt - r2 = rmax * (i + 1) / npt - mask_r12 = np.logical_and((r1 <= radius), 
(radius < r2)) - on_values = 0 - tmp = np.float64(0) - for j in dace.map[0:N]: - if mask_r12[j]: - tmp += data[j] - on_values += 1 - res[i] = tmp / on_values - return res - - rng = np.random.default_rng(42) - SN, Snpt = 1000, 10 - data, radius = rng.random((SN, )), rng.random((SN, )) - ref = dace_azimint_naive(data, radius, npt=Snpt) - - sdfg0 = dace_azimint_naive.to_sdfg() - tree = as_schedule_tree(sdfg0) - sdfg1 = as_sdfg(tree) - val = sdfg1(data=data, radius=radius, N=SN, npt=Snpt) - - assert np.allclose(val, ref) - - -if __name__ == "__main__": - # test_map_with_tasklet_and_library() - test_azimint_naive() diff --git a/tests/schedule_tree/conversions/map_stree_test.py b/tests/schedule_tree/conversions/map_stree_test.py deleted file mode 100644 index 7cb89b01aa..0000000000 --- a/tests/schedule_tree/conversions/map_stree_test.py +++ /dev/null @@ -1,243 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" Tests conversion of Map scopes from SDFG to ScheduleTree and back. """ -import dace -import numpy as np -from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree, as_sdfg - - -def test_simple_map(): - """ Tests a Map Scope with a single (non-WCR) output. 
""" - - M, N = (dace.symbol(s) for s in ('M', 'N')) - - @dace.program - def simple_map(A: dace.float32[M, N]): - B = np.zeros((M, N), dtype=A.dtype) - for i, j in dace.map[1:M-1, 1:N-1]: - with dace.tasklet: - c << A[i, j] - n << A[i-1, j] - s << A[i+1, j] - w << A[i, j-1] - e << A[i, j+1] - out = (c + n + s + w + e) / 5 - out >> B[i, j] - return B - - sdfg_pre = simple_map.to_sdfg() - tree = as_schedule_tree(sdfg_pre) - sdfg_post = as_sdfg(tree) - - rng = np.random.default_rng(42) - A = rng.random((20, 20), dtype=np.float32) - ref = np.zeros_like(A) - for i, j in dace.map[1:19, 1:19]: - ref[i, j] = (A[i, j] + A[i-1, j] + A[i+1, j] + A[i, j-1] + A[i, j+1]) / 5 - - val_pre = sdfg_pre(A=A, M=20, N=20) - val_post = sdfg_post(A=A, M=20, N=20) - - assert np.allclose(val_pre, ref) - assert np.allclose(val_post, ref) - - -def test_multiple_outputs_map(): - """ Tests a Map Scope with multiple (non-WCR) outputs. """ - - M, N = (dace.symbol(s) for s in ('M', 'N')) - - @dace.program - def multiple_outputs_map(A: dace.float32[M, N]): - B = np.zeros((2, M, N), dtype=A.dtype) - for i, j in dace.map[1:M-1, 1:N-1]: - with dace.tasklet: - c << A[i, j] - n << A[i-1, j] - s << A[i+1, j] - w << A[i, j-1] - e << A[i, j+1] - out0 = (c + n + s + w + e) / 5 - out1 = c / 2 + (n + s + w + e) / 2 - out0 >> B[0, i, j] - out1 >> B[1, i, j] - return B - - sdfg_pre = multiple_outputs_map.to_sdfg() - tree = as_schedule_tree(sdfg_pre) - sdfg_post = as_sdfg(tree) - - rng = np.random.default_rng(42) - A = rng.random((20, 20), dtype=np.float32) - ref = np.zeros_like(A, shape=(2, 20, 20)) - for i, j in dace.map[1:19, 1:19]: - ref[0, i, j] = (A[i, j] + A[i-1, j] + A[i+1, j] + A[i, j-1] + A[i, j+1]) / 5 - ref[1, i, j] = A[i, j] / 2 + (A[i-1, j] + A[i+1, j] + A[i, j-1] + A[i, j+1]) / 2 - - val_pre = sdfg_pre(A=A, M=20, N=20) - val_post = sdfg_post(A=A, M=20, N=20) - - assert np.allclose(val_pre, ref) - assert np.allclose(val_post, ref) - - -def test_simple_wcr_map(): - """ Tests a Map Scope with a single 
WCR output. """ - - M, N = (dace.symbol(s) for s in ('M', 'N')) - - @dace.program - def simple_wcr_map(A: dace.float32[M, N]): - ret = dace.float32(0) - for i, j in dace.map[1:M-1, 1:N-1]: - with dace.tasklet: - c << A[i, j] - n << A[i-1, j] - s << A[i+1, j] - w << A[i, j-1] - e << A[i, j+1] - out = (c + n + s + w + e) / 5 - out >> ret(1, lambda x, y: x + y) - return ret - - sdfg_pre = simple_wcr_map.to_sdfg() - tree = as_schedule_tree(sdfg_pre) - sdfg_post = as_sdfg(tree) - - rng = np.random.default_rng(42) - A = rng.random((20, 20), dtype=np.float32) - ref = np.float32(0) - for i, j in dace.map[1:19, 1:19]: - ref += (A[i, j] + A[i-1, j] + A[i+1, j] + A[i, j-1] + A[i, j+1]) / 5 - - val_pre = sdfg_pre(A=A, M=20, N=20)[0] - val_post = sdfg_post(A=A, M=20, N=20)[0] - - assert np.allclose(val_pre, ref) - assert np.allclose(val_post, ref) - - -def test_simple_wcr_map2(): - """ Tests a Map Scope with a single WCR output. The output is also (fake) input with WCR. """ - - M, N = (dace.symbol(s) for s in ('M', 'N')) - - @dace.program - def simple_wcr_map2(A: dace.float32[M, N]): - ret = dace.float32(0) - for i, j in dace.map[1:M-1, 1:N-1]: - with dace.tasklet: - c << A[i, j] - n << A[i-1, j] - s << A[i+1, j] - w << A[i, j-1] - e << A[i, j+1] - inp << ret(1, lambda x, y: x + y) - out = (c + n + s + w + e) / 5 - out >> ret(1, lambda x, y: x + y) - return ret - - sdfg_pre = simple_wcr_map2.to_sdfg() - tree = as_schedule_tree(sdfg_pre) - sdfg_post = as_sdfg(tree) - - rng = np.random.default_rng(42) - A = rng.random((20, 20), dtype=np.float32) - ref = np.float32(0) - for i, j in dace.map[1:19, 1:19]: - ref += (A[i, j] + A[i-1, j] + A[i+1, j] + A[i, j-1] + A[i, j+1]) / 5 - - val_pre = sdfg_pre(A=A, M=20, N=20)[0] - val_post = sdfg_post(A=A, M=20, N=20)[0] - - assert np.allclose(val_pre, ref) - assert np.allclose(val_post, ref) - - -def test_multiple_outputs_mixed_map(): - """ Tests a Map Scope with multiple (WCR and non-WCR) outputs. 
""" - - M, N = (dace.symbol(s) for s in ('M', 'N')) - - @dace.program - def multiple_outputs_map(A: dace.float32[M, N]): - B = np.zeros((M, N), dtype=A.dtype) - ret = np.float32(0) - for i, j in dace.map[1:M-1, 1:N-1]: - with dace.tasklet: - c << A[i, j] - n << A[i-1, j] - s << A[i+1, j] - w << A[i, j-1] - e << A[i, j+1] - out0 = (c + n + s + w + e) / 5 - out1 = out0 - out0 >> B[i, j] - out1 >> ret(1, lambda x, y: x + y) - return B, ret - - sdfg_pre = multiple_outputs_map.to_sdfg() - tree = as_schedule_tree(sdfg_pre) - sdfg_post = as_sdfg(tree) - - rng = np.random.default_rng(42) - A = rng.random((20, 20), dtype=np.float32) - ref0 = np.zeros_like(A, shape=(20, 20)) - ref1 = np.float32(0) - for i, j in dace.map[1:19, 1:19]: - ref0[i, j] = (A[i, j] + A[i-1, j] + A[i+1, j] + A[i, j-1] + A[i, j+1]) / 5 - ref1 += ref0[i, j] - - val_pre = sdfg_pre(A=A, M=20, N=20) - val_post = sdfg_post(A=A, M=20, N=20) - - assert np.allclose(val_pre[0], ref0) - assert np.allclose(val_pre[1], ref1) - assert np.allclose(val_post[0], ref0) - assert np.allclose(val_post[1], ref1) - - -def test_nested_simple_map(): - """ Tests a nested Map Scope with a single (non-WCR) output. 
""" - - M, N = (dace.symbol(s) for s in ('M', 'N')) - - @dace.program - def nested_simple_map(A: dace.float32[M, N]): - B = np.zeros((M, N), dtype=A.dtype) - for i, j in dace.map[1:M-2:2, 1:N-2:2]: - inA = A[i-1:i+3, j-1:j+3] - for k, l in dace.map[0:2, 0:2]: - with dace.tasklet: - c << inA[k+1, l+1] - n << inA[k, l+1] - s << inA[k+2, l+1] - w << inA[k+1, l] - e << inA[k+1, l+2] - out = (c + n + s + w + e) / 5 - out >> B[i+k, j+l] - return B - - sdfg_pre = nested_simple_map.to_sdfg() - tree = as_schedule_tree(sdfg_pre) - sdfg_post = as_sdfg(tree) - - rng = np.random.default_rng(42) - A = rng.random((20, 20), dtype=np.float32) - ref = np.zeros_like(A) - for i, j in dace.map[1:19, 1:19]: - ref[i, j] = (A[i, j] + A[i-1, j] + A[i+1, j] + A[i, j-1] + A[i, j+1]) / 5 - - val_pre = sdfg_pre(A=A, M=20, N=20) - val_post = sdfg_post(A=A, M=20, N=20) - - assert np.allclose(val_pre, ref) - assert np.allclose(val_post, ref) - - -if __name__ == "__main__": - test_simple_map() - test_multiple_outputs_map() - test_simple_wcr_map() - test_simple_wcr_map2() - test_multiple_outputs_mixed_map() - test_nested_simple_map() diff --git a/tests/schedule_tree/conversions/tasklet_test.py b/tests/schedule_tree/conversions/tasklet_test.py deleted file mode 100644 index bc527fa05c..0000000000 --- a/tests/schedule_tree/conversions/tasklet_test.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" Tests conversion of Tasklets from SDFG to ScheduleTree and back. """ -import dace -import numpy as np -from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree, as_sdfg - - -def test_simple_tasklet(): - """ Tests a Tasklet with a single (non-WCR) output. 
""" - - @dace.program - def simple_tasklet(A: dace.float32[3, 3]): - ret = dace.float32(0) - with dace.tasklet: - c << A[1, 1] - n << A[0, 1] - s << A[2, 1] - w << A[1, 0] - e << A[1, 2] - out = (c + n + s + w + e) / 5 - out >> ret - return ret - - sdfg_pre = simple_tasklet.to_sdfg() - tree = as_schedule_tree(sdfg_pre) - sdfg_post = as_sdfg(tree) - - rng = np.random.default_rng(42) - A = rng.random((3, 3), dtype=np.float32) - ref = (A[1, 1] + A[0, 1] + A[2, 1] + A[1, 0] + A[1, 2]) / 5 - - val_pre = sdfg_pre(A=A)[0] - val_post = sdfg_post(A=A)[0] - - assert np.allclose(val_pre, ref) - assert np.allclose(val_post, ref) - - -def test_multiple_outputs_tasklet(): - """ Tests a Tasklet with multiple (non-WCR) outputs. """ - - @dace.program - def multiple_outputs_tasklet(A: dace.float32[3, 3]): - ret = np.empty((2,), dtype=np.float32) - with dace.tasklet: - c << A[1, 1] - n << A[0, 1] - s << A[2, 1] - w << A[1, 0] - e << A[1, 2] - out0 = (c + n + s + w + e) / 5 - out1 = c / 2 + (n + s + w + e) / 2 - out0 >> ret[0] - out1 >> ret[1] - return ret - - sdfg_pre = multiple_outputs_tasklet.to_sdfg() - tree = as_schedule_tree(sdfg_pre) - sdfg_post = as_sdfg(tree) - - rng = np.random.default_rng(42) - A = rng.random((3, 3), dtype=np.float32) - ref = np.empty((2,), dtype=np.float32) - ref[0] = (A[1, 1] + A[0, 1] + A[2, 1] + A[1, 0] + A[1, 2]) / 5 - ref[1] = A[1, 1] / 2 + (A[0, 1] + A[2, 1] + A[1, 0] + A[1, 2]) / 2 - - val_pre = sdfg_pre(A=A) - val_post = sdfg_post(A=A) - - assert np.allclose(val_pre, ref) - assert np.allclose(val_post, ref) - - -def test_simple_wcr_tasklet(): - """ Tests a Tasklet with a single WCR output. 
""" - - @dace.program - def simple_wcr_tasklet(A: dace.float32[3, 3]): - ret = dace.float32(2) - with dace.tasklet: - c << A[1, 1] - n << A[0, 1] - s << A[2, 1] - w << A[1, 0] - e << A[1, 2] - out = (c + n + s + w + e) / 5 - out >> ret(1, lambda x, y: x + y) - return ret - - sdfg_pre = simple_wcr_tasklet.to_sdfg() - tree = as_schedule_tree(sdfg_pre) - sdfg_post = as_sdfg(tree) - - rng = np.random.default_rng(42) - A = rng.random((3, 3), dtype=np.float32) - ref = 2 + (A[1, 1] + A[0, 1] + A[2, 1] + A[1, 0] + A[1, 2]) / 5 - - val_pre = sdfg_pre(A=A)[0] - val_post = sdfg_post(A=A)[0] - - assert np.allclose(val_pre, ref) - assert np.allclose(val_post, ref) - - -def test_simple_wcr_tasklet2(): - """ Tests a tasklet with a single WCR output. The output is also (fake) input with WCR. """ - - @dace.program - def simple_wcr_tasklet2(A: dace.float32[3, 3]): - ret = dace.float32(2) - with dace.tasklet: - c << A[1, 1] - n << A[0, 1] - s << A[2, 1] - w << A[1, 0] - e << A[1, 2] - inp << ret(1, lambda x, y: x + y) - out = (c + n + s + w + e) / 5 - out >> ret(1, lambda x, y: x + y) - return ret - - sdfg_pre = simple_wcr_tasklet2.to_sdfg() - tree = as_schedule_tree(sdfg_pre) - sdfg_post = as_sdfg(tree) - - rng = np.random.default_rng(42) - A = rng.random((3, 3), dtype=np.float32) - ref = 2 + (A[1, 1] + A[0, 1] + A[2, 1] + A[1, 0] + A[1, 2]) / 5 - - val_pre = sdfg_pre(A=A)[0] - val_post = sdfg_post(A=A)[0] - - assert np.allclose(val_pre, ref) - assert np.allclose(val_post, ref) - - -if __name__ == "__main__": - test_simple_tasklet() - test_multiple_outputs_tasklet() - test_simple_wcr_tasklet() - test_simple_wcr_tasklet2() diff --git a/tests/schedule_tree/passes/canonicalize_if_test.py b/tests/schedule_tree/passes/canonicalize_if_test.py deleted file mode 100644 index af4b5c10ea..0000000000 --- a/tests/schedule_tree/passes/canonicalize_if_test.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
-""" Tests canonicalization of If/Elif/Else scopes. """ -import dace -import numpy as np -from dace.sdfg.analysis.schedule_tree import passes, treenodes as tnodes -from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree, as_sdfg - - -class IfCounter(tnodes.ScheduleNodeVisitor): - - if_count: int - elif_count: int - else_count: int - - def __init__(self): - self.if_count = 0 - self.elif_count = 0 - self.else_count = 0 - - def visit_IfScope(self, node: tnodes.IfScope): - self.if_count += 1 - self.generic_visit(node) - - def visit_ElifScope(self, node: tnodes.ElifScope): - self.elif_count += 1 - self.generic_visit(node) - - def visit_ElseScope(self, node: tnodes.ElseScope): - self.else_count += 1 - self.generic_visit(node) - - -def test_ifelifelse_canonicalization(): - - @dace.program - def ifelifelse(c: dace.int64): - out = 0 - if c < 0: - out = c - 1 - elif c == 0: - pass - else: - out = c % 2 - return out - - sdfg_pre = ifelifelse.to_sdfg() - tree = as_schedule_tree(sdfg_pre) - ifcounter_pre = IfCounter() - ifcounter_pre.visit(tree) - ifcount = ifcounter_pre.if_count + ifcounter_pre.elif_count + ifcounter_pre.else_count - - passes.canonicalize_if(tree) - ifcounter_post = IfCounter() - ifcounter_post.visit(tree) - assert ifcounter_post.if_count == ifcount - assert ifcounter_post.elif_count == 0 - assert ifcounter_post.else_count == 0 - - sdfg_post = as_sdfg(tree) - - for c in (-100, 0, 100): - ref = ifelifelse.f(c) - val0 = sdfg_pre(c=c) - val1 = sdfg_post(c=c) - assert val0[0] == ref - assert val1[0] == ref - - -def test_ifelifelse_canonicalization2(): - - @dace.program - def ifelifelse2(c: dace.int64): - out = 0 - if c < 0: - if c < -100: - out = c + 1 - elif c < -50: - out = c + 2 - else: - out = c + 3 - elif c == 0: - pass - else: - if c > 100: - out = c % 2 - elif c > 50: - out = c % 3 - else: - out = c % 4 - return out - - sdfg_pre = ifelifelse2.to_sdfg() - tree = as_schedule_tree(sdfg_pre) - ifcounter_pre = IfCounter() - 
ifcounter_pre.visit(tree) - ifcount = ifcounter_pre.if_count + ifcounter_pre.elif_count + ifcounter_pre.else_count - - passes.canonicalize_if(tree) - ifcounter_post = IfCounter() - ifcounter_post.visit(tree) - assert ifcounter_post.if_count == ifcount - assert ifcounter_post.elif_count == 0 - assert ifcounter_post.else_count == 0 - - sdfg_post = as_sdfg(tree) - - for c in (-200, -70, -20, 0, 15, 67, 122): - ref = ifelifelse2.f(c) - val0 = sdfg_pre(c=c) - val1 = sdfg_post(c=c) - assert val0[0] == ref - assert val1[0] == ref - - -if __name__ == "__main__": - test_ifelifelse_canonicalization() - test_ifelifelse_canonicalization2() diff --git a/tests/schedule_tree/transformations/if_fission_test.py b/tests/schedule_tree/transformations/if_fission_test.py deleted file mode 100644 index 717d6a2469..0000000000 --- a/tests/schedule_tree/transformations/if_fission_test.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" Tests fission of If scopes. 
""" -import dace -import numpy as np -from dace.sdfg.analysis.schedule_tree import passes, transformations as ttrans, treenodes as tnodes -from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree, as_sdfg - - -class IfCounter(tnodes.ScheduleNodeVisitor): - - if_count: int - elif_count: int - else_count: int - - def __init__(self): - self.if_count = 0 - self.elif_count = 0 - self.else_count = 0 - - def visit_IfScope(self, node: tnodes.IfScope): - self.if_count += 1 - self.generic_visit(node) - - def visit_ElifScope(self, node: tnodes.ElifScope): - self.elif_count += 1 - self.generic_visit(node) - - def visit_ElseScope(self, node: tnodes.ElseScope): - self.else_count += 1 - self.generic_visit(node) - - -def test_if_fission(): - - @dace.program - def ifelifelse(c: dace.int64): - out0, out1 = 0, 0 - if c < 0: - out0 = -5 - out1 = -10 - if c == 0: - pass - if c > 0: - out0 = 5 - out1 = 10 - return out0, out1 - - sdfg_pre = ifelifelse.to_sdfg() - tree = as_schedule_tree(sdfg_pre) - ifcounter_pre = IfCounter() - ifcounter_pre.visit(tree) - if ifcounter_pre.elif_count > 0 or ifcounter_pre.else_count > 0: - passes.canonicalize_if(tree) - ifcounter_post = IfCounter() - ifcounter_post.visit(tree) - assert ifcounter_post.elif_count == 0 - assert ifcounter_post.else_count == 0 - - for child in list(tree.children): - if isinstance(child, tnodes.IfScope): - ttrans.if_fission(child) - - sdfg_post = as_sdfg(tree) - - for c in (-100, 0, 100): - ref = ifelifelse.f(c) - val0 = sdfg_pre(c=c) - val1 = sdfg_post(c=c) - assert val0 == ref - assert val1 == ref - - -if __name__ == "__main__": - test_if_fission() From 3c834892aeab6b90820e15880d9e3a826c005c98 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 20 Jul 2023 17:24:18 -0700 Subject: [PATCH 70/98] Add set of tests --- .../analysis/schedule_tree/sdfg_to_tree.py | 4 +- dace/sdfg/analysis/schedule_tree/treenodes.py | 2 +- tests/schedule_tree/nesting_test.py | 182 ++++++++++++++++++ 3 files changed, 186 
insertions(+), 2 deletions(-) create mode 100644 tests/schedule_tree/nesting_test.py diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 72df0a0472..e2b04fd124 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -114,7 +114,9 @@ def normalize_memlet(sdfg: SDFG, state: SDFGState, original: gr.MultiConnectorEd :param data: The data descriptor. :return: A new memlet. """ - edge = copy.deepcopy(original) + # Shallow copy edge + edge = gr.MultiConnectorEdge(original.src, original.src_conn, original.dst, original.dst_conn, + copy.deepcopy(original.data), original.key) edge.data.try_initialize(sdfg, state, edge) if edge.data.data == data: diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 5d200efc11..b76ef24edc 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. import copy from dataclasses import dataclass, field from dace import nodes, data, subsets diff --git a/tests/schedule_tree/nesting_test.py b/tests/schedule_tree/nesting_test.py new file mode 100644 index 0000000000..fba09bb441 --- /dev/null +++ b/tests/schedule_tree/nesting_test.py @@ -0,0 +1,182 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" +Nesting and dealiasing tests for schedule trees. 
+""" +import dace +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree +from dace.transformation.dataflow import RemoveSliceView + +import pytest +from typing import List + +N = dace.symbol('N') +T = dace.symbol('T') + + +def test_stree_mpath_multiscope(): + + @dace.program + def tester(A: dace.float64[N, N]): + for i in dace.map[0:N:T]: + for j, k in dace.map[0:T, 0:N]: + for l in dace.map[0:T]: + A[i + j, k + l] = 1 + + # The test should generate different SDFGs for different simplify configurations, + # but the same schedule tree + stree = as_schedule_tree(tester.to_sdfg()) + assert [type(n) for n in stree.preorder_traversal()][1:] == [tn.MapScope, tn.MapScope, tn.MapScope, tn.TaskletNode] + + +def test_stree_mpath_multiscope_dependent(): + + @dace.program + def tester(A: dace.float64[N, N]): + for i in dace.map[0:N:T]: + for j, k in dace.map[0:T, 0:N]: + for l in dace.map[0:k]: + A[i + j, l] = 1 + + # The test should generate different SDFGs for different simplify configurations, + # but the same schedule tree + stree = as_schedule_tree(tester.to_sdfg()) + assert [type(n) for n in stree.preorder_traversal()][1:] == [tn.MapScope, tn.MapScope, tn.MapScope, tn.TaskletNode] + + +def test_stree_mpath_nested(): + + @dace.program + def nester(A, i, k, j): + for l in range(k): + A[i + j, l] = 1 + + @dace.program + def tester(A: dace.float64[N, N]): + for i in dace.map[0:N:T]: + for j, k in dace.map[0:T, 0:N]: + nester(A, i, j, k) + + stree = as_schedule_tree(tester.to_sdfg()) + + # Simplifying yields a different SDFG due to scalars and symbols, so testing is slightly different + simplified = dace.Config.get_bool('optimizer', 'automatic_simplification') + + if simplified: + assert [type(n) + for n in stree.preorder_traversal()][1:] == [tn.MapScope, tn.MapScope, tn.ForScope, tn.TaskletNode] + + tasklet: tn.TaskletNode = list(stree.preorder_traversal())[-1] + + if simplified: + assert 
str(next(iter(tasklet.out_memlets.values()))) == 'A[i + k, l]' + else: + assert str(next(iter(tasklet.out_memlets.values()))).endswith(', l]') + + +@pytest.mark.parametrize('dst_subset', (False, True)) +def test_stree_copy_same_scope(dst_subset): + sdfg = dace.SDFG('tester') + sdfg.add_array('A', [3 * N], dace.float64) + sdfg.add_array('B', [3 * N], dace.float64) + state = sdfg.add_state() + + r = state.add_read('A') + w = state.add_write('B') + if not dst_subset: + state.add_nedge(r, w, dace.Memlet(data='A', subset='2*N:3*N', other_subset='N:2*N')) + else: + state.add_nedge(r, w, dace.Memlet(data='B', subset='N:2*N', other_subset='2*N:3*N')) + + stree = as_schedule_tree(sdfg) + assert len(stree.children) == 1 and isinstance(stree.children[0], tn.CopyNode) + assert stree.children[0].target == 'B' + assert stree.children[0].as_string() == 'B[N:2*N] = copy A[2*N:3*N]' + + +@pytest.mark.parametrize('dst_subset', (False, True)) +def test_stree_copy_different_scope(dst_subset): + sdfg = dace.SDFG('tester') + sdfg.add_array('A', [3 * N], dace.float64) + sdfg.add_array('B', [3 * N], dace.float64) + state = sdfg.add_state() + + r = state.add_read('A') + w = state.add_write('B') + me, mx = state.add_map('something', dict(i='0:1')) + if not dst_subset: + state.add_memlet_path(r, me, w, memlet=dace.Memlet(data='A', subset='2*N:3*N', other_subset='N + i:2*N + i')) + else: + state.add_memlet_path(r, me, w, memlet=dace.Memlet(data='B', subset='N + i:2*N + i', other_subset='2*N:3*N')) + state.add_nedge(w, mx, dace.Memlet()) + + stree = as_schedule_tree(sdfg) + stree_nodes = list(stree.preorder_traversal())[1:] + assert [type(n) for n in stree_nodes] == [tn.MapScope, tn.CopyNode] + assert stree_nodes[-1].target == 'B' + assert stree_nodes[-1].as_string() == 'B[N + i:2*N + i] = copy A[2*N:3*N]' + + +def test_dealias_nested_call(): + + @dace.program + def nester(a, b): + b[:] = a + + @dace.program + def tester(a: dace.float64[40]): + nester(a[1:21], a[10:30]) + + sdfg = 
tester.to_sdfg(simplify=False) + sdfg.apply_transformations_repeated(RemoveSliceView) + + stree = as_schedule_tree(sdfg) + assert len(stree.children) == 1 + copy = stree.children[0] + assert isinstance(copy, tn.CopyNode) + assert copy.target == 'a' + assert copy.memlet.data == 'a' + assert str(copy.memlet.src_subset) == '10:30' + assert str(copy.memlet.dst_subset) == '1:21' + + +def test_dealias_memlet_composition(): + + def nester2(c): + c[2] = 1 + + def nester1(b): + nester2(b[-5:]) + + @dace.program + def tester(a: dace.float64[N, N]): + nester1(a[:, 1]) + + # Simplifying yields a different SDFG due to views, so testing is slightly different + simplified = dace.Config.get_bool('optimizer', 'automatic_simplification') + + sdfg = tester.to_sdfg() + stree = as_schedule_tree(sdfg) + if simplified: + assert len(stree.children) == 1 + tasklet = stree.children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert str(next(iter(tasklet.out_memlets.values()))) == 'a[N - 3, 1]' + else: + print(stree.as_string()) + assert len(stree.children) == 3 + # TODO: Should views precede tasklet? 
+ stree_nodes = list(stree.preorder_traversal())[1:] + assert [type(n) for n in stree_nodes] == [tn.TaskletNode, tn.ViewNode, tn.ViewNode] + + +if __name__ == '__main__': + test_stree_mpath_multiscope() + test_stree_mpath_multiscope_dependent() + test_stree_mpath_nested() + test_stree_copy_same_scope(False) + test_stree_copy_same_scope(True) + test_stree_copy_different_scope(False) + test_stree_copy_different_scope(True) + test_dealias_nested_call() + test_dealias_memlet_composition() From 298457dfa80c39e8b8f000d2377017b7beb7c6f8 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 21 Jul 2023 07:51:20 -0700 Subject: [PATCH 71/98] Add another dealiasing test --- tests/schedule_tree/nesting_test.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/tests/schedule_tree/nesting_test.py b/tests/schedule_tree/nesting_test.py index fba09bb441..f5d71d7b5c 100644 --- a/tests/schedule_tree/nesting_test.py +++ b/tests/schedule_tree/nesting_test.py @@ -163,13 +163,38 @@ def tester(a: dace.float64[N, N]): assert isinstance(tasklet, tn.TaskletNode) assert str(next(iter(tasklet.out_memlets.values()))) == 'a[N - 3, 1]' else: - print(stree.as_string()) assert len(stree.children) == 3 # TODO: Should views precede tasklet? 
stree_nodes = list(stree.preorder_traversal())[1:] assert [type(n) for n in stree_nodes] == [tn.TaskletNode, tn.ViewNode, tn.ViewNode] +def test_dealias_interstate_edge(): + sdfg = dace.SDFG('tester') + sdfg.add_array('A', [20], dace.float64) + sdfg.add_array('B', [20], dace.float64) + + nsdfg = dace.SDFG('nester') + nsdfg.add_array('A', [19], dace.float64) + nsdfg.add_array('B', [15], dace.float64) + nsdfg.add_symbol('m', dace.float64) + nstate1 = nsdfg.add_state() + nstate2 = nsdfg.add_state() + nsdfg.add_edge(nstate1, nstate2, dace.InterstateEdge(condition='B[1] > 0', assignments=dict(m='A[2]'))) + + # Connect to nested SDFG both with flipped definitions and offset memlets + state = sdfg.add_state() + nsdfg_node = state.add_nested_sdfg(nsdfg, None, {'A', 'B'}, {}) + ra = state.add_read('A') + rb = state.add_read('B') + state.add_edge(ra, None, nsdfg_node, 'B', dace.Memlet('A[1:20]')) + state.add_edge(rb, None, nsdfg_node, 'A', dace.Memlet('B[2:17]')) + + sdfg.validate() + stree = as_schedule_tree(sdfg) + nodes = list(stree.preorder_traversal())[1:] + + if __name__ == '__main__': test_stree_mpath_multiscope() test_stree_mpath_multiscope_dependent() @@ -180,3 +205,4 @@ def tester(a: dace.float64[N, N]): test_stree_copy_different_scope(True) test_dealias_nested_call() test_dealias_memlet_composition() + test_dealias_interstate_edge() From 1812174eb9711771d877fa1adeaa6744efac3406 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 21 Jul 2023 07:51:30 -0700 Subject: [PATCH 72/98] Fix stateif printout --- dace/sdfg/analysis/schedule_tree/treenodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index b76ef24edc..933b30736f 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -181,7 +181,7 @@ class StateIfScope(IfScope): def as_string(self, indent: int = 0): result = indent * INDENTATION + 
f'stateif {self.condition.as_string}:\n' - return result + super().as_string(indent) + return result + super(IfScope, self).as_string(indent) @dataclass From 62789ed9781ae1d368b3ee573e372295c5b8cb58 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 21 Jul 2023 07:55:52 -0700 Subject: [PATCH 73/98] Fix reference set alignment --- dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py | 2 +- tests/schedule_tree/schedule_test.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index e2b04fd124..4473e74348 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -282,7 +282,7 @@ def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[ # 2. Check for reference sets if isinstance(e.dst, dace.nodes.AccessNode) and e.dst_conn == 'set': assert isinstance(e.dst.desc(sdfg), dace.data.Reference) - result[e] = tn.RefSetNode(target=e.data.data, + result[e] = tn.RefSetNode(target=e.dst.data, memlet=e.data, src_desc=sdfg.arrays[e.data.data], ref_desc=sdfg.arrays[e.dst.data]) diff --git a/tests/schedule_tree/schedule_test.py b/tests/schedule_tree/schedule_test.py index 9827c3cd50..ada70cf822 100644 --- a/tests/schedule_tree/schedule_test.py +++ b/tests/schedule_tree/schedule_test.py @@ -201,8 +201,11 @@ def test_reference(): s2.add_edge(s2.add_access('B'), None, s2.add_access('ref'), 'set', dace.Memlet('B[0:20]')) end.add_nedge(end.add_access('ref'), end.add_access('C'), dace.Memlet('ref[0:20]')) - # TODO: Align reference memlet - # print(as_schedule_tree(sdfg).as_string()) + stree = as_schedule_tree(sdfg) + nodes = list(stree.preorder_traversal())[1:] + assert [type(n) for n in nodes] == [tn.IfScope, tn.RefSetNode, tn.ElseScope, tn.RefSetNode, tn.CopyNode] + assert nodes[1].as_string() == 'ref = refset to A[0:20]' + assert nodes[3].as_string() == 'ref = refset to B[0:20]' 
def test_code_to_code(): From 42bc759a7862bfc21000ecbe7207880ff3bbc44e Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 21 Jul 2023 07:58:05 -0700 Subject: [PATCH 74/98] Fix printout of tasklets without outputs --- dace/sdfg/analysis/schedule_tree/treenodes.py | 2 ++ tests/schedule_tree/schedule_test.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 933b30736f..da56dc16aa 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -273,6 +273,8 @@ class TaskletNode(ScheduleTreeNode): def as_string(self, indent: int = 0): in_memlets = ', '.join(f'{v}' for v in self.in_memlets.values()) out_memlets = ', '.join(f'{v}' for v in self.out_memlets.values()) + if not out_memlets: + return indent * INDENTATION + f'tasklet({in_memlets})' return indent * INDENTATION + f'{out_memlets} = tasklet({in_memlets})' diff --git a/tests/schedule_tree/schedule_test.py b/tests/schedule_tree/schedule_test.py index ada70cf822..4eabf50497 100644 --- a/tests/schedule_tree/schedule_test.py +++ b/tests/schedule_tree/schedule_test.py @@ -216,10 +216,10 @@ def test_code_to_code(): t2 = state.add_tasklet('b', {'inp'}, {}, 'print(inp)', side_effects=True) state.add_edge(t1, 'out', t2, 'inp', dace.Memlet('scal')) - # TODO: Nicely print tasklets without outputs stree = as_schedule_tree(sdfg) assert len(stree.children) == 2 assert all(isinstance(c, tn.TaskletNode) for c in stree.children) + assert stree.children[1].as_string().startswith('tasklet(scal') def test_dyn_map_range(): From 4b43606395ade5421afc71c3702355ae7f625d5e Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 21 Jul 2023 09:06:38 -0700 Subject: [PATCH 75/98] Remove test --- .../schedule_tree/map_fission_test.py | 321 ------------------ 1 file changed, 321 deletions(-) delete mode 100644 tests/transformations/schedule_tree/map_fission_test.py diff --git 
a/tests/transformations/schedule_tree/map_fission_test.py b/tests/transformations/schedule_tree/map_fission_test.py deleted file mode 100644 index 2f12c7721c..0000000000 --- a/tests/transformations/schedule_tree/map_fission_test.py +++ /dev/null @@ -1,321 +0,0 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. -import dace -import numpy as np -from dace import nodes -from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree, as_sdfg -from dace.sdfg.analysis.schedule_tree.transformations import map_fission -from dace.transformation.helpers import nest_state_subgraph -from tests.transformations.mapfission_test import mapfission_sdfg - - -def test_map_with_tasklet_and_library(): - - N = dace.symbol('N') - @dace.program - def map_with_tasklet_and_library(A: dace.float32[N, 5, 5], B: dace.float32[N, 5, 5], cst: dace.int32): - out = np.ndarray((N, 5, 5), dtype=dace.float32) - for i in dace.map[0:N]: - out[i] = cst * (A[i] @ B[i]) - return out - - rng = np.random.default_rng(42) - A = rng.random((10, 5, 5), dtype=np.float32) - B = rng.random((10, 5, 5), dtype=np.float32) - cst = rng.integers(0, 100, dtype=np.int32) - ref = cst * (A @ B) - - val0 = map_with_tasklet_and_library(A, B, cst) - sdfg0 = map_with_tasklet_and_library.to_sdfg() - tree = as_schedule_tree(sdfg0) - result = map_fission(tree.children[0], tree) - assert result - sdfg1 = as_sdfg(tree) - val1 = sdfg1(A=A, B=B, cst=cst, N=A.shape[0]) - - assert np.allclose(val0, ref) - assert np.allclose(val1, ref) - - -def test_subgraph(): - - rng = np.random.default_rng(42) - A = rng.random((4, )) - ref = np.zeros([2], dtype=np.float64) - ref[0] = (A[0] + A[1]) + (A[0] * 2 * A[1] * 2) + (A[0] * 3) + 5.0 - ref[1] = (A[2] + A[3]) + (A[2] * 2 * A[3] * 2) + (A[2] * 3) + 5.0 - val = np.empty((2, )) - - sdfg0 = mapfission_sdfg() - tree = as_schedule_tree(sdfg0) - pcode, _ = tree.as_python() - print(pcode) - result = map_fission(tree.children[0], tree) - assert result - pcode, _ = 
tree.as_python() - print(pcode) - sdfg1 = as_sdfg(tree) - sdfg1(A=A, B=val) - - assert np.allclose(val, ref) - - -def test_nested_subgraph(): - - rng = np.random.default_rng(42) - A = rng.random((4, )) - ref = np.zeros([2], dtype=np.float64) - ref[0] = (A[0] + A[1]) + (A[0] * 2 * A[1] * 2) + (A[0] * 3) + 5.0 - ref[1] = (A[2] + A[3]) + (A[2] * 2 * A[3] * 2) + (A[2] * 3) + 5.0 - val = np.empty((2, )) - - sdfg0 = mapfission_sdfg() - state = sdfg0.nodes()[0] - topmap = next(node for node in state.nodes() if isinstance(node, nodes.MapEntry) and node.label == 'outer') - subgraph = state.scope_subgraph(topmap, include_entry=False, include_exit=False) - nest_state_subgraph(sdfg0, state, subgraph) - tree = as_schedule_tree(sdfg0) - result = map_fission(tree.children[0], tree) - assert result - pcode, _ = tree.as_python() - print(pcode) - sdfg1 = as_sdfg(tree) - sdfg1(A=A, B=val) - - assert np.allclose(val, ref) - - -def test_nested_transient(): - """ Test nested SDFGs with transients. """ - - # Inner SDFG - nsdfg = dace.SDFG('nested') - nsdfg.add_array('a', [1], dace.float64) - nsdfg.add_array('b', [1], dace.float64) - nsdfg.add_transient('t', [1], dace.float64) - - # a->t state - nstate = nsdfg.add_state() - irnode = nstate.add_read('a') - task = nstate.add_tasklet('t1', {'inp'}, {'out'}, 'out = 2*inp') - iwnode = nstate.add_write('t') - nstate.add_edge(irnode, None, task, 'inp', dace.Memlet.simple('a', '0')) - nstate.add_edge(task, 'out', iwnode, None, dace.Memlet.simple('t', '0')) - - # t->a state - first_state = nstate - nstate = nsdfg.add_state() - irnode = nstate.add_read('t') - task = nstate.add_tasklet('t2', {'inp'}, {'out'}, 'out = 3*inp') - iwnode = nstate.add_write('b') - nstate.add_edge(irnode, None, task, 'inp', dace.Memlet.simple('t', '0')) - nstate.add_edge(task, 'out', iwnode, None, dace.Memlet.simple('b', '0')) - - nsdfg.add_edge(first_state, nstate, dace.InterstateEdge()) - - # Outer SDFG - sdfg = dace.SDFG('nested_transient_fission') - sdfg.add_array('A', 
[2], dace.float64) - state = sdfg.add_state() - rnode = state.add_read('A') - wnode = state.add_write('A') - me, mx = state.add_map('outer', dict(i='0:2')) - nsdfg_node = state.add_nested_sdfg(nsdfg, None, {'a'}, {'b'}) - state.add_memlet_path(rnode, me, nsdfg_node, dst_conn='a', memlet=dace.Memlet.simple('A', 'i')) - state.add_memlet_path(nsdfg_node, mx, wnode, src_conn='b', memlet=dace.Memlet.simple('A', 'i')) - - # self.assertGreater(sdfg.apply_transformations_repeated(MapFission), 0) - tree = as_schedule_tree(sdfg) - result = map_fission(tree.children[0], tree) - assert result - sdfg = as_sdfg(tree) - - # Test - A = np.random.rand(2) - expected = A * 6 - sdfg(A=A) - # self.assertTrue(np.allclose(A, expected)) - assert np.allclose(A, expected) - - -def test_inputs_outputs(): - """ - Test subgraphs where the computation modules that are in the middle - connect to the outside. - """ - - sdfg = dace.SDFG('inputs_outputs_fission') - sdfg.add_array('in1', [2], dace.float64) - sdfg.add_array('in2', [2], dace.float64) - sdfg.add_scalar('tmp', dace.float64, transient=True) - sdfg.add_array('out1', [2], dace.float64) - sdfg.add_array('out2', [2], dace.float64) - state = sdfg.add_state() - in1 = state.add_read('in1') - in2 = state.add_read('in2') - out1 = state.add_write('out1') - out2 = state.add_write('out2') - me, mx = state.add_map('outer', dict(i='0:2')) - t1 = state.add_tasklet('t1', {'i1'}, {'o1', 'o2'}, 'o1 = i1 * 2; o2 = i1 * 5') - t2 = state.add_tasklet('t2', {'i1', 'i2'}, {'o1'}, 'o1 = i1 * i2') - state.add_memlet_path(in1, me, t1, dst_conn='i1', memlet=dace.Memlet.simple('in1', 'i')) - state.add_memlet_path(in2, me, t2, dst_conn='i2', memlet=dace.Memlet.simple('in2', 'i')) - state.add_edge(t1, 'o1', t2, 'i1', dace.Memlet.simple('tmp', '0')) - state.add_memlet_path(t2, mx, out1, src_conn='o1', memlet=dace.Memlet.simple('out1', 'i')) - state.add_memlet_path(t1, mx, out2, src_conn='o2', memlet=dace.Memlet.simple('out2', 'i')) - - # 
self.assertGreater(sdfg.apply_transformations(MapFission), 0) - tree = as_schedule_tree(sdfg) - result = map_fission(tree.children[0], tree) - assert result - sdfg = as_sdfg(tree) - - # Test - A, B, C, D = tuple(np.random.rand(2) for _ in range(4)) - expected_C = (A * 2) * B - expected_D = A * 5 - sdfg(in1=A, in2=B, out1=C, out2=D) - # self.assertTrue(np.allclose(C, expected_C)) - # self.assertTrue(np.allclose(D, expected_D)) - assert np.allclose(C, expected_C) - assert np.allclose(D, expected_D) - - -def test_offsets(): - sdfg = dace.SDFG('mapfission_offsets') - sdfg.add_array('A', [20], dace.float64) - sdfg.add_scalar('interim', dace.float64, transient=True) - state = sdfg.add_state() - me, mx = state.add_map('outer', dict(i='10:20')) - - t1 = state.add_tasklet('addone', {'a'}, {'b'}, 'b = a + 1') - t2 = state.add_tasklet('addtwo', {'a'}, {'b'}, 'b = a + 2') - - aread = state.add_read('A') - awrite = state.add_write('A') - state.add_memlet_path(aread, me, t1, dst_conn='a', memlet=dace.Memlet.simple('A', 'i')) - state.add_edge(t1, 'b', t2, 'a', dace.Memlet.simple('interim', '0')) - state.add_memlet_path(t2, mx, awrite, src_conn='b', memlet=dace.Memlet.simple('A', 'i')) - - # self.assertGreater(sdfg.apply_transformations(MapFission), 0) - tree = as_schedule_tree(sdfg) - pcode, _ = tree.as_python() - print(pcode) - result = map_fission(tree.children[0], tree) - assert result - pcode, _ = tree.as_python() - print(pcode) - sdfg = as_sdfg(tree) - - # dace.propagate_memlets_sdfg(sdfg) - # sdfg.validate() - - # Test - A = np.random.rand(20) - expected = A.copy() - expected[10:] += 3 - sdfg(A=A) - # self.assertTrue(np.allclose(A, expected)) - assert np.allclose(A, expected) - - -def test_offsets_array(): - sdfg = dace.SDFG('mapfission_offsets2') - sdfg.add_array('A', [20], dace.float64) - sdfg.add_array('interim', [1], dace.float64, transient=True) - state = sdfg.add_state() - me, mx = state.add_map('outer', dict(i='10:20')) - - t1 = state.add_tasklet('addone', {'a'}, 
{'b'}, 'b = a + 1') - interim = state.add_access('interim') - t2 = state.add_tasklet('addtwo', {'a'}, {'b'}, 'b = a + 2') - - aread = state.add_read('A') - awrite = state.add_write('A') - state.add_memlet_path(aread, me, t1, dst_conn='a', memlet=dace.Memlet.simple('A', 'i')) - state.add_edge(t1, 'b', interim, None, dace.Memlet.simple('interim', '0')) - state.add_edge(interim, None, t2, 'a', dace.Memlet.simple('interim', '0')) - state.add_memlet_path(t2, mx, awrite, src_conn='b', memlet=dace.Memlet.simple('A', 'i')) - - # self.assertGreater(sdfg.apply_transformations(MapFission), 0) - tree = as_schedule_tree(sdfg) - result = map_fission(tree.children[0], tree) - assert result - sdfg = as_sdfg(tree) - - # dace.propagate_memlets_sdfg(sdfg) - # sdfg.validate() - - # Test - A = np.random.rand(20) - expected = A.copy() - expected[10:] += 3 - sdfg(A=A) - # self.assertTrue(np.allclose(A, expected)) - assert np.allclose(A, expected) - - -def test_mapfission_with_symbols(): - ''' - Tests MapFission in the case of a Map containing a NestedSDFG that is using some symbol from the top-level SDFG - missing from the NestedSDFG's symbol mapping. Please note that this is an unusual case that is difficult to - reproduce and ultimately unrelated to MapFission. Consider solving the underlying issue and then deleting this - test and the corresponding (obsolete) code in MapFission. 
- ''' - - M, N = dace.symbol('M'), dace.symbol('N') - - sdfg = dace.SDFG('tasklet_code_with_symbols') - sdfg.add_array('A', (M, N), dace.int32) - sdfg.add_array('B', (M, N), dace.int32) - - state = sdfg.add_state('parent', is_start_state=True) - me, mx = state.add_map('parent_map', {'i': '0:N'}) - - nsdfg = dace.SDFG('nested_sdfg') - nsdfg.add_scalar('inner_A', dace.int32) - nsdfg.add_scalar('inner_B', dace.int32) - - nstate = nsdfg.add_state('child', is_start_state=True) - na = nstate.add_access('inner_A') - nb = nstate.add_access('inner_B') - ta = nstate.add_tasklet('tasklet_A', {}, {'__out'}, '__out = M') - tb = nstate.add_tasklet('tasklet_B', {}, {'__out'}, '__out = M') - nstate.add_edge(ta, '__out', na, None, dace.Memlet.from_array('inner_A', nsdfg.arrays['inner_A'])) - nstate.add_edge(tb, '__out', nb, None, dace.Memlet.from_array('inner_B', nsdfg.arrays['inner_B'])) - - a = state.add_access('A') - b = state.add_access('B') - t = nodes.NestedSDFG('child_sdfg', nsdfg, {}, {'inner_A', 'inner_B'}, {}) - nsdfg.parent = state - nsdfg.parent_sdfg = sdfg - nsdfg.parent_nsdfg_node = t - state.add_node(t) - state.add_nedge(me, t, dace.Memlet()) - state.add_memlet_path(t, mx, a, memlet=dace.Memlet('A[0, i]'), src_conn='inner_A') - state.add_memlet_path(t, mx, b, memlet=dace.Memlet('B[0, i]'), src_conn='inner_B') - - # num = sdfg.apply_transformations_repeated(MapFission) - tree = as_schedule_tree(sdfg) - result = map_fission(tree.children[0], tree) - assert result - sdfg = as_sdfg(tree) - - A = np.ndarray((2, 10), dtype=np.int32) - B = np.ndarray((2, 10), dtype=np.int32) - sdfg(A=A, B=B, M=2, N=10) - - ref = np.full((10, ), fill_value=2, dtype=np.int32) - - assert np.array_equal(A[0], ref) - assert np.array_equal(B[0], ref) - - -if __name__ == "__main__": - test_map_with_tasklet_and_library() - test_subgraph() - test_nested_subgraph() - test_nested_transient() - test_inputs_outputs() - test_offsets() - test_offsets_array() - test_mapfission_with_symbols() From 
b4a7984ec94ff1b5eb00cb474bc7d1fb766fc425 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 25 Jul 2023 23:07:03 -0700 Subject: [PATCH 76/98] Schedule tree: fix tests, print empty scopes in a nicer way --- dace/sdfg/analysis/schedule_tree/treenodes.py | 2 ++ tests/schedule_tree/schedule_test.py | 14 +++++++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index da56dc16aa..b96be06832 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -55,6 +55,8 @@ def __init__(self, child.parent = self def as_string(self, indent: int = 0): + if not self.children: + return (indent + 1) * INDENTATION + 'pass' return '\n'.join([child.as_string(indent + 1) for child in self.children]) def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: diff --git a/tests/schedule_tree/schedule_test.py b/tests/schedule_tree/schedule_test.py index 4eabf50497..6d41420856 100644 --- a/tests/schedule_tree/schedule_test.py +++ b/tests/schedule_tree/schedule_test.py @@ -73,7 +73,7 @@ def main(a: dace.float64[20, 10]): nest1(a[10:15]) nest1(a[15:]) - sdfg = main.to_sdfg() + sdfg = main.to_sdfg(simplify=True) stree = as_schedule_tree(sdfg) # Despite two levels of nesting, immediate children are the 4 for loops @@ -150,8 +150,10 @@ def test_irreducible_sub_sdfg(): # Add a loop following general block sdfg.add_loop(e, sdfg.add_state(), None, 'i', '0', 'i < 10', 'i + 1') - # TODO: Missing exit in stateif s2->e - # print(as_schedule_tree(sdfg).as_string()) + stree = as_schedule_tree(sdfg) + node_types = [type(n) for n in stree.preorder_traversal()] + assert node_types.count(tn.GBlock) == 1 # Only one gblock + assert node_types[-1] == tn.ForScope # Check that loop was detected def test_irreducible_in_loops(): @@ -176,8 +178,10 @@ def test_irreducible_in_loops(): # Avoiding undefined behavior sdfg.edges_between(l3, 
l4)[0].data.condition.as_string = 'i >= 5' - # TODO: gblock must cover the greatest common scope its labels are in. - # print(as_schedule_tree(sdfg).as_string()) + stree = as_schedule_tree(sdfg) + node_types = [type(n) for n in stree.preorder_traversal()] + assert node_types.count(tn.GBlock) == 1 + assert node_types.count(tn.ForScope) == 2 def test_reference(): From a4fe54273cd193f287855b130955783a729efad2 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 27 Jul 2023 00:03:54 -0700 Subject: [PATCH 77/98] Fix name collection in collision removal --- dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 4473e74348..9220274530 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -179,16 +179,16 @@ def remove_name_collisions(sdfg: SDFG): parent_node = nsdfg.parent_nsdfg_node # Preserve top-level SDFG names + do_not_replace = False if not parent_node: - continue + do_not_replace = True # Rename duplicate data containers for name, desc in nsdfg.arrays.items(): - - if not desc.transient: - continue - if name in identifiers_seen: + if not desc.transient or do_not_replace: + continue + new_name = data.find_new_name(name, identifiers_seen) replacements[name] = new_name name = new_name @@ -200,7 +200,7 @@ def remove_name_collisions(sdfg: SDFG): if parent_node is not None and name in parent_node.symbol_mapping: continue - if name in identifiers_seen: + if name in identifiers_seen and not do_not_replace: new_name = data.find_new_name(name, identifiers_seen) replacements[name] = new_name name = new_name @@ -208,7 +208,7 @@ def remove_name_collisions(sdfg: SDFG): # Rename duplicate constants for name in nsdfg.constants_prop.keys(): - if name in identifiers_seen: + if name in identifiers_seen and not do_not_replace: new_name = 
data.find_new_name(name, identifiers_seen) replacements[name] = new_name name = new_name From f6e5cd3f9af1f7e3c975eeee095dc83b5a2d6109 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 27 Jul 2023 00:04:35 -0700 Subject: [PATCH 78/98] Make naming edge case test more complex --- tests/schedule_tree/naming_test.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/schedule_tree/naming_test.py b/tests/schedule_tree/naming_test.py index d2eddc4bc4..702db6a4cf 100644 --- a/tests/schedule_tree/naming_test.py +++ b/tests/schedule_tree/naming_test.py @@ -148,15 +148,22 @@ def test_edgecase_symbol_mapping(): nsdfg.add_symbol('k', dace.int64) nstate = nsdfg.add_state() nstate2 = nsdfg.add_state() + nstate3 = nsdfg.add_state() nsdfg.add_edge(nstate, nstate2, dace.InterstateEdge(assignments={'k': 'M + 1'})) + nsdfg.add_edge(nstate2, nstate3, dace.InterstateEdge(assignments={'l': 'k'})) state2.add_nested_sdfg(nsdfg, None, {}, {}, {'N': 'M', 'M': 'N', 'k': 'M + 1'}) stree = as_schedule_tree(sdfg) - assert len(stree.children) == 1 + + # k is reassigned internally, so that should be preserved + assert len(stree.children) == 2 assert isinstance(stree.children[0], tn.AssignNode) - # TODO: "assign M + 1 = (N + 1)", target should stay "k" - assert str(stree.children[0].name) == 'k' + assert stree.children[0].name == 'k' + assert stree.children[0].value.as_string == '(N + 1)' + assert isinstance(stree.children[1], tn.AssignNode) + assert stree.children[1].name == 'l' + assert stree.children[1].value.as_string in ('k', '(N + 1)') def test_clash_iteration_symbols(): From 23d4df167e94b059764c199efc8a72cdbd7aa8f4 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 27 Jul 2023 00:31:31 -0700 Subject: [PATCH 79/98] Better and robust pass to replace symbol mappings --- .../analysis/schedule_tree/sdfg_to_tree.py | 31 +++++++++++++++++-- dace/sdfg/sdfg.py | 2 +- tests/schedule_tree/naming_test.py | 15 +++++---- 3 files changed, 39 insertions(+), 9 
deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 9220274530..11b90dbfd0 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -233,6 +233,34 @@ def _make_view_node(state: SDFGState, edge: gr.MultiConnectorEdge[Memlet], view_ view_desc=sdfg.arrays[view_name]) +def replace_symbols_until_set(nsdfg: dace.nodes.NestedSDFG): + """ + Replaces symbol values in a nested SDFG until their value has been reset. This is used for matching symbol + namespaces between an SDFG and a nested SDFG. + """ + from collections import defaultdict + from dace.transformation.passes.analysis import StateReachability + + mapping = nsdfg.symbol_mapping + sdfg = nsdfg.sdfg + reachable_states = StateReachability().apply_pass(sdfg, {})[sdfg.sdfg_id] + redefined_symbols: Dict[SDFGState, Set[str]] = defaultdict(set) + + # Collect redefined symbols + for e in sdfg.edges(): + redefined = e.data.assignments.keys() + redefined_symbols[e.dst] |= redefined + for reachable in reachable_states[e.dst]: + redefined_symbols[reachable] |= redefined + + # Replace everything but the redefined symbols + for state in sdfg.nodes(): + per_state_mapping = {k: v for k, v in mapping.items() if k not in redefined_symbols[state]} + symbolic.safe_replace(per_state_mapping, state.replace_dict) + for e in sdfg.out_edges(state): + symbolic.safe_replace(per_state_mapping, lambda d: e.data.replace_dict(d, replace_keys=False)) + + def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[Memlet], tn.ScheduleTreeNode]: """ Creates a dictionary mapping edges to their corresponding schedule tree nodes, if relevant. 
@@ -363,8 +391,7 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: nested_array_mapping = {} # Replace symbols and memlets in nested SDFGs to match the namespace of the parent SDFG - # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes - symbolic.safe_replace(node.symbol_mapping, node.sdfg.replace_dict) + replace_symbols_until_set(node) # Create memlets for nested SDFG mapping, or nview schedule nodes if slice cannot be determined for e in state.all_edges(node): diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index b42df62ea4..62d2d73815 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -710,7 +710,7 @@ def replace_dict(self, if replace_in_graph: # Replace in inter-state edges for edge in self.edges(): - edge.data.replace_dict(repldict) + edge.data.replace_dict(repldict, replace_keys=replace_keys) # Replace in states for state in self.nodes(): diff --git a/tests/schedule_tree/naming_test.py b/tests/schedule_tree/naming_test.py index 702db6a4cf..517716e1bc 100644 --- a/tests/schedule_tree/naming_test.py +++ b/tests/schedule_tree/naming_test.py @@ -147,6 +147,7 @@ def test_edgecase_symbol_mapping(): nsdfg.add_symbol('N', dace.int64) nsdfg.add_symbol('k', dace.int64) nstate = nsdfg.add_state() + nstate.add_tasklet('dosomething', {}, {}, 'print(k)', side_effects=True) nstate2 = nsdfg.add_state() nstate3 = nsdfg.add_state() nsdfg.add_edge(nstate, nstate2, dace.InterstateEdge(assignments={'k': 'M + 1'})) @@ -157,13 +158,15 @@ def test_edgecase_symbol_mapping(): stree = as_schedule_tree(sdfg) # k is reassigned internally, so that should be preserved - assert len(stree.children) == 2 - assert isinstance(stree.children[0], tn.AssignNode) - assert stree.children[0].name == 'k' - assert stree.children[0].value.as_string == '(N + 1)' + assert len(stree.children) == 3 + assert isinstance(stree.children[0], tn.TaskletNode) + assert 'M + 1' in stree.children[0].node.code.as_string assert isinstance(stree.children[1], 
tn.AssignNode) - assert stree.children[1].name == 'l' - assert stree.children[1].value.as_string in ('k', '(N + 1)') + assert stree.children[1].name == 'k' + assert stree.children[1].value.as_string == '(N + 1)' + assert isinstance(stree.children[2], tn.AssignNode) + assert stree.children[2].name == 'l' + assert stree.children[2].value.as_string in ('k', '(N + 1)') def test_clash_iteration_symbols(): From 71a55c9cb85a21bbc315d78a81889aa742c77f05 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 27 Jul 2023 00:32:11 -0700 Subject: [PATCH 80/98] Test both simplified and unsimplified modes --- tests/schedule_tree/nesting_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/schedule_tree/nesting_test.py b/tests/schedule_tree/nesting_test.py index f5d71d7b5c..16c9cc3b9d 100644 --- a/tests/schedule_tree/nesting_test.py +++ b/tests/schedule_tree/nesting_test.py @@ -140,7 +140,8 @@ def tester(a: dace.float64[40]): assert str(copy.memlet.dst_subset) == '1:21' -def test_dealias_memlet_composition(): +@pytest.mark.parametrize('simplify', (False, True)) +def test_dealias_memlet_composition(simplify): def nester2(c): c[2] = 1 @@ -152,12 +153,11 @@ def nester1(b): def tester(a: dace.float64[N, N]): nester1(a[:, 1]) - # Simplifying yields a different SDFG due to views, so testing is slightly different - simplified = dace.Config.get_bool('optimizer', 'automatic_simplification') - - sdfg = tester.to_sdfg() + sdfg = tester.to_sdfg(simplify=simplify) stree = as_schedule_tree(sdfg) - if simplified: + + # Simplifying yields a different SDFG due to views, so testing is slightly different + if simplify: assert len(stree.children) == 1 tasklet = stree.children[0] assert isinstance(tasklet, tn.TaskletNode) From 95c8d86029ddd3640fe9d69d6f2a2519a8bc0d06 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 27 Jul 2023 01:01:50 -0700 Subject: [PATCH 81/98] Make data container dealiasing robust to conflicting replacements --- 
dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py | 8 ++++---- tests/schedule_tree/nesting_test.py | 3 +++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 11b90dbfd0..94a163b633 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -1,4 +1,5 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +from collections import defaultdict import copy from typing import Dict, List, Set import dace @@ -7,8 +8,10 @@ from dace.sdfg.sdfg import InterstateEdge, SDFG from dace.sdfg.state import SDFGState from dace.sdfg import utils as sdutil, graph as gr +from dace.sdfg.replace import replace_datadesc_names from dace.frontend.python.astutils import negate_expr from dace.sdfg.analysis.schedule_tree import treenodes as tn, passes as stpasses +from dace.transformation.passes.analysis import StateReachability from dace.transformation.helpers import unsqueeze_memlet from dace.properties import CodeBlock from dace.memlet import Memlet @@ -88,7 +91,7 @@ def dealias_sdfg(sdfg: SDFG): e.data.other_subset = subsets.Range.from_array(parent_arr) if replacements: - nsdfg.replace_dict(replacements) + symbolic.safe_replace(replacements, lambda d: replace_datadesc_names(nsdfg, d), value_as_string=True) parent_node.in_connectors = { replacements[c] if c in replacements else c: t for c, t in parent_node.in_connectors.items() @@ -238,9 +241,6 @@ def replace_symbols_until_set(nsdfg: dace.nodes.NestedSDFG): Replaces symbol values in a nested SDFG until their value has been reset. This is used for matching symbol namespaces between an SDFG and a nested SDFG. 
""" - from collections import defaultdict - from dace.transformation.passes.analysis import StateReachability - mapping = nsdfg.symbol_mapping sdfg = nsdfg.sdfg reachable_states = StateReachability().apply_pass(sdfg, {})[sdfg.sdfg_id] diff --git a/tests/schedule_tree/nesting_test.py b/tests/schedule_tree/nesting_test.py index 16c9cc3b9d..eb8c4df3ae 100644 --- a/tests/schedule_tree/nesting_test.py +++ b/tests/schedule_tree/nesting_test.py @@ -193,6 +193,9 @@ def test_dealias_interstate_edge(): sdfg.validate() stree = as_schedule_tree(sdfg) nodes = list(stree.preorder_traversal())[1:] + assert [type(n) for n in nodes] == [tn.StateIfScope, tn.GotoNode, tn.AssignNode] + assert 'A[2]' in nodes[0].condition.as_string + assert 'B[4]' in nodes[-1].value.as_string if __name__ == '__main__': From 1743112181569de98065cbfc4b61bbf84c1ef79b Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 27 Jul 2023 01:11:07 -0700 Subject: [PATCH 82/98] One more test --- tests/schedule_tree/nesting_test.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/tests/schedule_tree/nesting_test.py b/tests/schedule_tree/nesting_test.py index eb8c4df3ae..ab96d48efd 100644 --- a/tests/schedule_tree/nesting_test.py +++ b/tests/schedule_tree/nesting_test.py @@ -119,6 +119,29 @@ def test_stree_copy_different_scope(dst_subset): def test_dealias_nested_call(): + @dace.program + def nester(a, b): + b[:] = a + + @dace.program + def tester(a: dace.float64[40], b: dace.float64[40]): + nester(b[1:21], a[10:30]) + + sdfg = tester.to_sdfg(simplify=False) + sdfg.apply_transformations_repeated(RemoveSliceView) + + stree = as_schedule_tree(sdfg) + assert len(stree.children) == 1 + copy = stree.children[0] + assert isinstance(copy, tn.CopyNode) + assert copy.target == 'a' + assert copy.memlet.data == 'b' + assert str(copy.memlet.src_subset) == '1:21' + assert str(copy.memlet.dst_subset) == '10:30' + + +def test_dealias_nested_call_samearray(): + @dace.program def 
nester(a, b): b[:] = a @@ -136,8 +159,8 @@ def tester(a: dace.float64[40]): assert isinstance(copy, tn.CopyNode) assert copy.target == 'a' assert copy.memlet.data == 'a' - assert str(copy.memlet.src_subset) == '10:30' - assert str(copy.memlet.dst_subset) == '1:21' + assert str(copy.memlet.src_subset) == '1:21' + assert str(copy.memlet.dst_subset) == '10:30' @pytest.mark.parametrize('simplify', (False, True)) @@ -207,5 +230,6 @@ def test_dealias_interstate_edge(): test_stree_copy_different_scope(False) test_stree_copy_different_scope(True) test_dealias_nested_call() + test_dealias_nested_call_samearray() test_dealias_memlet_composition() test_dealias_interstate_edge() From 5fa505588b72e29526797fc882f4d04493b8d782 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 30 Jul 2023 11:13:54 -0700 Subject: [PATCH 83/98] Schedule tree: Fix memlets replacement not addressing both input and output connectors --- .../analysis/schedule_tree/sdfg_to_tree.py | 51 +++++++++++++++---- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 94a163b633..9bf966f059 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -134,31 +134,56 @@ def normalize_memlet(sdfg: SDFG, state: SDFGState, original: gr.MultiConnectorEd memlet.data = data memlet.subset = new_subset memlet.other_subset = new_osubset + memlet._is_data_src = True return memlet -def replace_memlets(sdfg: SDFG, array_mapping: Dict[str, Memlet]): +def replace_memlets(sdfg: SDFG, input_mapping: Dict[str, Memlet], output_mapping: Dict[str, Memlet]): """ Replaces all uses of data containers in memlets and interstate edges in an SDFG. :param sdfg: The SDFG. - :param array_mapping: A mapping from internal data descriptor names to external memlets. + :param input_mapping: A mapping from internal data descriptor names to external input memlets. 
+ :param output_mapping: A mapping from internal data descriptor names to external output memlets. """ # TODO: Support Interstate edges for state in sdfg.states(): for e in state.edges(): - if e.data.data in array_mapping: - e.data = unsqueeze_memlet(e.data, array_mapping[e.data.data]) + mpath = state.memlet_path(e) + src = mpath[0].src + dst = mpath[-1].dst + memlet = e.data + if isinstance(src, dace.nodes.AccessNode) and src.data in input_mapping: + memlet = unsqueeze_memlet(memlet, input_mapping[src.data], use_src_subset=True) + if isinstance(dst, dace.nodes.AccessNode) and dst.data in output_mapping: + memlet = unsqueeze_memlet(memlet, output_mapping[dst.data], use_dst_subset=True) + + # Other cases + if memlet is e.data: + if e.data.data in input_mapping: + memlet = unsqueeze_memlet(memlet, input_mapping[e.data.data]) + elif e.data.data in output_mapping: + memlet = unsqueeze_memlet(memlet, output_mapping[e.data.data]) + + e.data = memlet + for e in sdfg.edges(): repl_dict = dict() syms = e.data.read_symbols() for memlet in e.data.get_read_memlets(sdfg.arrays): - if memlet.data in array_mapping: - repl_dict[str(memlet)] = unsqueeze_memlet(memlet, array_mapping[memlet.data]) + if memlet.data in input_mapping or memlet.data in output_mapping: + # If array name is both in the input connectors and output connectors with different + # memlets, this is undefined behavior. 
Prefer output + if memlet.data in input_mapping: + mapping = input_mapping + if memlet.data in output_mapping: + mapping = output_mapping + + repl_dict[str(memlet)] = unsqueeze_memlet(memlet, mapping[memlet.data]) if memlet.data in syms: syms.remove(memlet.data) for s in syms: - if s in array_mapping: - repl_dict[s] = str(array_mapping[s]) + if s in input_mapping: + repl_dict[s] = str(input_mapping[s]) e.data.replace_dict(repl_dict) @@ -388,7 +413,8 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: elif isinstance(node, dace.nodes.ExitNode): result = scopes.pop() elif isinstance(node, dace.nodes.NestedSDFG): - nested_array_mapping = {} + nested_array_mapping_input = {} + nested_array_mapping_output = {} # Replace symbols and memlets in nested SDFGs to match the namespace of the parent SDFG replace_symbols_until_set(node) @@ -407,7 +433,10 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: if expanded: # "newaxis" slices will be seen as views (for now) no_mapping = True else: - nested_array_mapping[conn] = e.data + if e.dst is node: + nested_array_mapping_input[conn] = e.data + else: + nested_array_mapping_output[conn] = e.data if no_mapping: # Must use view (nview = nested SDFG view) result.append( @@ -417,7 +446,7 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: src_desc=sdfg.arrays[e.data.data], view_desc=node.sdfg.arrays[conn])) - replace_memlets(node.sdfg, nested_array_mapping) + replace_memlets(node.sdfg, nested_array_mapping_input, nested_array_mapping_output) # Insert the nested SDFG flattened nested_stree = as_schedule_tree(node.sdfg, in_place=True, toplevel=False) From ef7aa4efc21c7b60930cc06955b61924d809d646 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sat, 23 Sep 2023 14:41:19 -0700 Subject: [PATCH 84/98] Fix dealiasing for data->data edges and interstate edges --- .../analysis/schedule_tree/sdfg_to_tree.py | 85 +++++++++++++++---- dace/sdfg/analysis/schedule_tree/treenodes.py 
| 2 +- dace/sdfg/state.py | 17 ++-- tests/schedule_tree/nesting_test.py | 4 +- 4 files changed, 80 insertions(+), 28 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 9bf966f059..08a939530e 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -7,7 +7,7 @@ from dace.codegen import control_flow as cf from dace.sdfg.sdfg import InterstateEdge, SDFG from dace.sdfg.state import SDFGState -from dace.sdfg import utils as sdutil, graph as gr +from dace.sdfg import utils as sdutil, graph as gr, nodes as nd from dace.sdfg.replace import replace_datadesc_names from dace.frontend.python.astutils import negate_expr from dace.sdfg.analysis.schedule_tree import treenodes as tn, passes as stpasses @@ -19,6 +19,11 @@ import time import sys +NODE_TO_SCOPE_TYPE = { + dace.nodes.MapEntry: tn.MapScope, + dace.nodes.ConsumeEntry: tn.ConsumeScope, + dace.nodes.PipelineEntry: tn.PipelineScope, +} def dealias_sdfg(sdfg: SDFG): for nsdfg in sdfg.all_sdfgs_recursive(): @@ -68,8 +73,35 @@ def dealias_sdfg(sdfg: SDFG): nsdfg.arrays[name] = child_arr for state in nsdfg.states(): for e in state.edges(): - if e.data.data in child_names: - e.data = unsqueeze_memlet(e.data, parent_edges[e.data.data].data) + if not state.is_leaf_memlet(e): + continue + + mpath = state.memlet_path(e) + src, dst = mpath[0].src, mpath[-1].dst + + # We need to take directionality of the memlet into account and unsqueeze either to source or + # destination subset + if isinstance(src, nd.AccessNode) and src.data in child_names: + src_data = src.data + new_src_memlet = unsqueeze_memlet(e.data, parent_edges[src.data].data, use_src_subset=True) + else: + src_data = None + new_src_memlet = e.data + # We need to take directionality of the memlet into account + if isinstance(dst, nd.AccessNode) and dst.data in child_names: + dst_data = dst.data + new_dst_memlet = 
unsqueeze_memlet(e.data, parent_edges[dst.data].data, use_dst_subset=True) + else: + dst_data = None + new_dst_memlet = e.data + + e.data.src_subset = new_src_memlet.subset + e.data.dst_subset = new_dst_memlet.subset + if e.data.data == src_data: + e.data.data = new_src_memlet.data + elif e.data.data == dst_data: + e.data.data = new_dst_memlet.data + for e in nsdfg.edges(): repl_dict = dict() syms = e.data.read_symbols() @@ -145,7 +177,6 @@ def replace_memlets(sdfg: SDFG, input_mapping: Dict[str, Memlet], output_mapping :param input_mapping: A mapping from internal data descriptor names to external input memlets. :param output_mapping: A mapping from internal data descriptor names to external output memlets. """ - # TODO: Support Interstate edges for state in sdfg.states(): for e in state.edges(): mpath = state.memlet_path(e) @@ -153,18 +184,34 @@ def replace_memlets(sdfg: SDFG, input_mapping: Dict[str, Memlet], output_mapping dst = mpath[-1].dst memlet = e.data if isinstance(src, dace.nodes.AccessNode) and src.data in input_mapping: - memlet = unsqueeze_memlet(memlet, input_mapping[src.data], use_src_subset=True) + src_data = src.data + src_memlet = unsqueeze_memlet(memlet, input_mapping[src.data], use_src_subset=True) + else: + src_data = None + src_memlet = None if isinstance(dst, dace.nodes.AccessNode) and dst.data in output_mapping: - memlet = unsqueeze_memlet(memlet, output_mapping[dst.data], use_dst_subset=True) + dst_data = dst.data + dst_memlet = unsqueeze_memlet(memlet, output_mapping[dst.data], use_dst_subset=True) + else: + dst_data = None + dst_memlet = None - # Other cases - if memlet is e.data: + # Other cases (code->code) + if src_data is None and dst_data is None: if e.data.data in input_mapping: memlet = unsqueeze_memlet(memlet, input_mapping[e.data.data]) elif e.data.data in output_mapping: memlet = unsqueeze_memlet(memlet, output_mapping[e.data.data]) - - e.data = memlet + e.data = memlet + else: + if src_memlet is not None: + memlet.src_subset 
= src_memlet.subset + if dst_memlet is not None: + memlet.dst_subset = dst_memlet.subset + if memlet.data == src_data: + memlet.data = src_memlet.data + elif memlet.data == dst_data: + memlet.data = dst_memlet.data for e in sdfg.edges(): repl_dict = dict() @@ -178,13 +225,22 @@ def replace_memlets(sdfg: SDFG, input_mapping: Dict[str, Memlet], output_mapping if memlet.data in output_mapping: mapping = output_mapping - repl_dict[str(memlet)] = unsqueeze_memlet(memlet, mapping[memlet.data]) + repl_dict[str(memlet)] = str(unsqueeze_memlet(memlet, mapping[memlet.data])) if memlet.data in syms: syms.remove(memlet.data) for s in syms: if s in input_mapping: repl_dict[s] = str(input_mapping[s]) - e.data.replace_dict(repl_dict) + + # Manual replacement with strings + # TODO(later): Would be MUCH better to use e.data.replace_dict(repl_dict, replace_keys=False) + for find, replace in repl_dict.items(): + for k, v in e.data.assignments.items(): + if find in v: + e.data.assignments[k] = v.replace(find, replace) + condstr = e.data.condition.as_string + if find in condstr: + e.data.condition.as_string = condstr.replace(find, replace) def remove_name_collisions(sdfg: SDFG): @@ -381,11 +437,6 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: :return: A string for the whole state """ result: List[tn.ScheduleTreeNode] = [] - NODE_TO_SCOPE_TYPE = { - dace.nodes.MapEntry: tn.MapScope, - dace.nodes.ConsumeEntry: tn.ConsumeScope, - dace.nodes.PipelineEntry: tn.PipelineScope, - } sdfg = state.parent edge_to_stree: Dict[gr.MultiConnectorEdge[Memlet], tn.ScheduleTreeNode] = prepare_schedule_tree_edges(state) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index b96be06832..dc72105fc9 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -67,7 +67,7 @@ def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: for child in self.children: yield 
from child.preorder_traversal() - # TODO: Get input/output memlets? + # TODO: Helper function that gets input/output memlets of the scope @dataclass diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 27e8536342..538f2114b9 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -411,6 +411,14 @@ def scope_children(self, ################################################################### # Query, subgraph, and replacement methods + def is_leaf_memlet(self, e): + if isinstance(e.src, nd.ExitNode) and e.src_conn and e.src_conn.startswith('OUT_'): + return False + if isinstance(e.dst, nd.EntryNode) and e.dst_conn and e.dst_conn.startswith('IN_'): + return False + return True + + def used_symbols(self, all_symbols: bool) -> Set[str]: """ Returns a set of symbol names that are used in the state. @@ -444,16 +452,9 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: freesyms |= n.free_symbols # Free symbols from memlets - def _is_leaf_memlet(e): - if isinstance(e.src, nd.ExitNode) and e.src_conn and e.src_conn.startswith('OUT_'): - return False - if isinstance(e.dst, nd.EntryNode) and e.dst_conn and e.dst_conn.startswith('IN_'): - return False - return True - for e in self.edges(): # If used for code generation, only consider memlet tree leaves - if not all_symbols and not _is_leaf_memlet(e): + if not all_symbols and not self.is_leaf_memlet(e): continue freesyms |= e.data.used_symbols(all_symbols) diff --git a/tests/schedule_tree/nesting_test.py b/tests/schedule_tree/nesting_test.py index ab96d48efd..a68128ebf9 100644 --- a/tests/schedule_tree/nesting_test.py +++ b/tests/schedule_tree/nesting_test.py @@ -8,7 +8,6 @@ from dace.transformation.dataflow import RemoveSliceView import pytest -from typing import List N = dace.symbol('N') T = dace.symbol('T') @@ -231,5 +230,6 @@ def test_dealias_interstate_edge(): test_stree_copy_different_scope(True) test_dealias_nested_call() test_dealias_nested_call_samearray() - test_dealias_memlet_composition() + 
test_dealias_memlet_composition(False) + test_dealias_memlet_composition(True) test_dealias_interstate_edge() From 6cae4a83a0e3e3ef9b87c10bdd5a5dc41aeb387b Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sat, 23 Sep 2023 18:48:38 -0700 Subject: [PATCH 85/98] Schedule views to show in order of viewing rather than graph --- dace/sdfg/analysis/schedule_tree/passes.py | 5 +- .../analysis/schedule_tree/sdfg_to_tree.py | 83 ++++++++++++++++--- dace/sdfg/analysis/schedule_tree/treenodes.py | 2 + tests/schedule_tree/nesting_test.py | 3 +- tests/schedule_tree/schedule_test.py | 26 +++++- 5 files changed, 101 insertions(+), 18 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/passes.py b/dace/sdfg/analysis/schedule_tree/passes.py index 52a58adc32..feda258e20 100644 --- a/dace/sdfg/analysis/schedule_tree/passes.py +++ b/dace/sdfg/analysis/schedule_tree/passes.py @@ -51,10 +51,7 @@ def remove_empty_scopes(stree: tn.ScheduleTreeScope): class RemoveEmptyScopes(tn.ScheduleNodeTransformer): - def visit(self, node: tn.ScheduleTreeNode): - if not isinstance(node, tn.ScheduleTreeScope): - return super().visit(node) - + def visit_scope(self, node: tn.ScheduleTreeScope): if len(node.children) == 0: return None diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 08a939530e..37a955f5c3 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -16,6 +16,7 @@ from dace.properties import CodeBlock from dace.memlet import Memlet +import networkx as nx import time import sys @@ -25,6 +26,7 @@ dace.nodes.PipelineEntry: tn.PipelineScope, } + def dealias_sdfg(sdfg: SDFG): for nsdfg in sdfg.all_sdfgs_recursive(): @@ -345,10 +347,12 @@ def replace_symbols_until_set(nsdfg: dace.nodes.NestedSDFG): def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[Memlet], tn.ScheduleTreeNode]: """ Creates a dictionary mapping edges to their 
corresponding schedule tree nodes, if relevant. + This handles view edges, reference sets, and dynamic map inputs. :param state: The state. """ result: Dict[gr.MultiConnectorEdge[Memlet], tn.ScheduleTreeNode] = {} + scope_to_edges: Dict[nd.EntryNode, List[gr.MultiConnectorEdge[Memlet]]] = defaultdict(list) edges_to_ignore = set() sdfg = state.parent @@ -378,6 +382,8 @@ def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[ if e is vedge: viewed_node = sdutil.get_view_node(state, e.src) result[e] = _make_view_node(state, e, e.src.data, viewed_node.data) + scope = state.entry_node(e.dst if mtree.downwards else e.src) + scope_to_edges[scope].append(e) continue if isinstance(e.dst, dace.nodes.AccessNode): desc = e.dst.desc(sdfg) @@ -386,6 +392,8 @@ def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[ if e is vedge: viewed_node = sdutil.get_view_node(state, e.dst) result[e] = _make_view_node(state, e, e.dst.data, viewed_node.data) + scope = state.entry_node(e.dst if mtree.downwards else e.src) + scope_to_edges[scope].append(e) continue # 2. Check for reference sets @@ -395,6 +403,8 @@ def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[ memlet=e.data, src_desc=sdfg.arrays[e.data.data], ref_desc=sdfg.arrays[e.dst.data]) + scope = state.entry_node(e.dst if mtree.downwards else e.src) + scope_to_edges[scope].append(e) continue # 3. 
Check for copies @@ -426,7 +436,10 @@ def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[ new_memlet = normalize_memlet(sdfg, state, e, outermost_node.data) result[e] = tn.CopyNode(target=target_name, memlet=new_memlet) - return result + scope = state.entry_node(e.dst if mtree.downwards else e.src) + scope_to_edges[scope].append(e) + + return result, scope_to_edges def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: @@ -439,9 +452,15 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: result: List[tn.ScheduleTreeNode] = [] sdfg = state.parent - edge_to_stree: Dict[gr.MultiConnectorEdge[Memlet], tn.ScheduleTreeNode] = prepare_schedule_tree_edges(state) + edge_to_stree: Dict[gr.MultiConnectorEdge[Memlet], tn.ScheduleTreeNode] + scope_to_edges: Dict[nd.EntryNode, List[gr.MultiConnectorEdge[Memlet]]] + edge_to_stree, scope_to_edges = prepare_schedule_tree_edges(state) edges_to_ignore = set() + # Handle all unscoped edges to generate output views + views = _generate_views_in_scope(scope_to_edges[None], edge_to_stree, sdfg, state) + result.extend(views) + scopes: List[List[tn.ScheduleTreeNode]] = [] for node in sdutil.scope_aware_topological_sort(state): if isinstance(node, dace.nodes.EntryNode): @@ -453,6 +472,10 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: result.append(edge_to_stree[e]) edges_to_ignore.add(e) + # Handle all scoped edges to generate (views) + views = _generate_views_in_scope(scope_to_edges[node], edge_to_stree, sdfg, state) + result.extend(views) + # Create scope node and add to stack scopes.append(result) subnodes = [] @@ -466,6 +489,7 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: elif isinstance(node, dace.nodes.NestedSDFG): nested_array_mapping_input = {} nested_array_mapping_output = {} + generated_nviews = set() # Replace symbols and memlets in nested SDFGs to match the namespace of the parent SDFG 
replace_symbols_until_set(node) @@ -490,12 +514,14 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: nested_array_mapping_output[conn] = e.data if no_mapping: # Must use view (nview = nested SDFG view) - result.append( - tn.NView(target=conn, - source=e.data.data, - memlet=e.data, - src_desc=sdfg.arrays[e.data.data], - view_desc=node.sdfg.arrays[conn])) + if conn not in generated_nviews: + result.append( + tn.NView(target=conn, + source=e.data.data, + memlet=e.data, + src_desc=sdfg.arrays[e.data.data], + view_desc=node.sdfg.arrays[conn])) + generated_nviews.add(conn) replace_memlets(node.sdfg, nested_array_mapping_input, nested_array_mapping_output) @@ -521,10 +547,13 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: # result.append(tn.LibraryCall(sdfg=sdfg, node=node, in_memlets=in_memlets, out_memlets=out_memlets)) elif isinstance(node, dace.nodes.AccessNode): # If one of the neighboring edges has a schedule tree node attached to it, use that + # (except for views, which were generated above) for e in state.all_edges(node): if e in edges_to_ignore: continue if e in edge_to_stree: + if isinstance(edge_to_stree[e], tn.ViewNode): + continue result.append(edge_to_stree[e]) edges_to_ignore.add(e) @@ -533,11 +562,43 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: return result +def _generate_views_in_scope(edges: List[gr.MultiConnectorEdge[Memlet]], + edge_to_stree: Dict[gr.MultiConnectorEdge[Memlet], tn.ScheduleTreeNode], sdfg: SDFG, + state: SDFGState) -> List[tn.ScheduleTreeNode]: + """ + Generates all view and reference set edges in the correct order. This function is intended to be used + at the beginning of a scope. 
+ """ + result: List[tn.ScheduleTreeNode] = [] + + # Make a dependency graph of all the views + g = nx.DiGraph() + node_to_stree = {} + for e in edges: + if e not in edge_to_stree: + continue + st = edge_to_stree[e] + if not isinstance(st, tn.ViewNode): + continue + g.add_edge(st.source, st.target) + node_to_stree[st.target] = st + + # Traverse in order and deduplicate + already_generated = set() + for n in nx.topological_sort(g): + if n in node_to_stree and n not in already_generated: + result.append(node_to_stree[n]) + already_generated.add(n) + + return result + + def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) -> tn.ScheduleTreeScope: """ - Converts an SDFG into a schedule tree. The schedule tree is a tree of nodes that represent the execution order of the SDFG. - Each node in the tree can either represent a single statement (symbol assignment, tasklet, copy, library node, etc.) or - a ``ScheduleTreeScope`` block (map, for-loop, pipeline, etc.) that contains other nodes. + Converts an SDFG into a schedule tree. The schedule tree is a tree of nodes that represent the execution order of + the SDFG. + Each node in the tree can either represent a single statement (symbol assignment, tasklet, copy, library node, etc.) + or a ``ScheduleTreeScope`` block (map, for-loop, pipeline, etc.) that contains other nodes. It can be used to generate code from an SDFG, or to perform schedule transformations on the SDFG. For example, erasing an empty if branch, or merging two consecutive for-loops. 
The SDFG can then be reconstructed via the diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index dc72105fc9..52b6db0361 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -375,6 +375,8 @@ def visit(self, node: ScheduleTreeNode): """Visit a node.""" if isinstance(node, list): return [self.visit(snode) for snode in node] + if isinstance(node, ScheduleTreeScope) and hasattr(self, 'visit_scope'): + return self.visit_scope(node) method = 'visit_' + node.__class__.__name__ visitor = getattr(self, method, self.generic_visit) diff --git a/tests/schedule_tree/nesting_test.py b/tests/schedule_tree/nesting_test.py index a68128ebf9..161f15d6c1 100644 --- a/tests/schedule_tree/nesting_test.py +++ b/tests/schedule_tree/nesting_test.py @@ -186,9 +186,8 @@ def tester(a: dace.float64[N, N]): assert str(next(iter(tasklet.out_memlets.values()))) == 'a[N - 3, 1]' else: assert len(stree.children) == 3 - # TODO: Should views precede tasklet? 
stree_nodes = list(stree.preorder_traversal())[1:] - assert [type(n) for n in stree_nodes] == [tn.TaskletNode, tn.ViewNode, tn.ViewNode] + assert [type(n) for n in stree_nodes] == [tn.ViewNode, tn.ViewNode, tn.TaskletNode] def test_dealias_interstate_edge(): diff --git a/tests/schedule_tree/schedule_test.py b/tests/schedule_tree/schedule_test.py index 6d41420856..09779c670f 100644 --- a/tests/schedule_tree/schedule_test.py +++ b/tests/schedule_tree/schedule_test.py @@ -128,7 +128,7 @@ def main(a: dace.float64[20, 10]): sdfg = main.to_sdfg() stree = as_schedule_tree(sdfg) - assert any(isinstance(node, tn.NView) for node in stree.children) + assert isinstance(stree.children[0], tn.NView) def test_irreducible_sub_sdfg(): @@ -252,6 +252,29 @@ def spmv(A_row: dace.uint32[H + 1], A_col: dace.uint32[nnz], A_val: dace.float32 assert isinstance(dynrangemap, tn.MapScope) +def test_multiview(): + sdfg = dace.SDFG('tester') + sdfg.add_array('A', [20, 20], dace.float64) + sdfg.add_array('B', [20, 20], dace.float64) + sdfg.add_view('Av', [400], dace.float64) + sdfg.add_view('Avv', [10, 40], dace.float64) + sdfg.add_view('Bv', [400], dace.float64) + sdfg.add_view('Bvv', [10, 40], dace.float64) + state = sdfg.add_state() + av = state.add_access('Av') + bv = state.add_access('Bv') + bvv = state.add_access('Bvv') + avv = state.add_access('Avv') + state.add_edge(state.add_read('A'), None, av, None, dace.Memlet('A[0:20, 0:20]')) + state.add_edge(av, None, avv, 'views', dace.Memlet('Av[0:400]')) + state.add_edge(avv, None, bvv, None, dace.Memlet('Avv[0:10, 0:40]')) + state.add_edge(bvv, 'views', bv, None, dace.Memlet('Bv[0:400]')) + state.add_edge(bv, 'views', state.add_write('B'), None, dace.Memlet('Bv[0:400]')) + + stree = as_schedule_tree(sdfg) + assert [type(n) for n in stree.children] == [tn.ViewNode, tn.ViewNode, tn.ViewNode, tn.ViewNode, tn.CopyNode] + + if __name__ == '__main__': test_for_in_map_in_for() test_libnode() @@ -263,3 +286,4 @@ def spmv(A_row: dace.uint32[H + 1], 
A_col: dace.uint32[nnz], A_val: dace.float32 test_reference() test_code_to_code() test_dyn_map_range() + test_multiview() From 1b6a912015bac16755b80f7fe534ae31e19851fe Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sat, 23 Sep 2023 19:30:33 -0700 Subject: [PATCH 86/98] Fix constructor --- dace/libraries/standard/nodes/reduce.py | 2 +- tests/symbol_dependent_transients_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/libraries/standard/nodes/reduce.py b/dace/libraries/standard/nodes/reduce.py index 45ca3cacb3..dd026ea62c 100644 --- a/dace/libraries/standard/nodes/reduce.py +++ b/dace/libraries/standard/nodes/reduce.py @@ -1578,7 +1578,7 @@ def __init__(self, @staticmethod def from_json(json_obj, context=None): - ret = Reduce("lambda a, b: a", None) + ret = Reduce('reduce', 'lambda a, b: a', None) dace.serialize.set_properties_from_json(ret, json_obj, context=context) return ret diff --git a/tests/symbol_dependent_transients_test.py b/tests/symbol_dependent_transients_test.py index f718abf379..8033b6b196 100644 --- a/tests/symbol_dependent_transients_test.py +++ b/tests/symbol_dependent_transients_test.py @@ -45,7 +45,7 @@ def _make_sdfg(name, storage=dace.dtypes.StorageType.CPU_Heap, isview=False): body2_state.add_nedge(read_a, read_tmp1, dace.Memlet(f'A[2:{N}-2, 2:{N}-2, i:{N}]')) else: read_tmp1 = body2_state.add_read('tmp1') - rednode = standard.Reduce(wcr='lambda a, b : a + b', identity=0) + rednode = standard.Reduce('sum', wcr='lambda a, b : a + b', identity=0) if storage == dace.dtypes.StorageType.GPU_Global: rednode.implementation = 'CUDA (device)' elif storage == dace.dtypes.StorageType.FPGA_Global: From b2d6cd3b45cf1293eab089f1f2cd6dbaffaff9e7 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 24 Sep 2023 04:04:04 -0700 Subject: [PATCH 87/98] Fix incorrect subset modification in dealiasing --- dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git 
a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 37a955f5c3..b33568e864 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -88,17 +88,19 @@ def dealias_sdfg(sdfg: SDFG): new_src_memlet = unsqueeze_memlet(e.data, parent_edges[src.data].data, use_src_subset=True) else: src_data = None - new_src_memlet = e.data + new_src_memlet = None # We need to take directionality of the memlet into account if isinstance(dst, nd.AccessNode) and dst.data in child_names: dst_data = dst.data new_dst_memlet = unsqueeze_memlet(e.data, parent_edges[dst.data].data, use_dst_subset=True) else: dst_data = None - new_dst_memlet = e.data + new_dst_memlet = None - e.data.src_subset = new_src_memlet.subset - e.data.dst_subset = new_dst_memlet.subset + if new_src_memlet is not None: + e.data.src_subset = new_src_memlet.subset + if new_dst_memlet is not None: + e.data.dst_subset = new_dst_memlet.subset if e.data.data == src_data: e.data.data = new_src_memlet.data elif e.data.data == dst_data: From 92c7a14ef2ee49bb510de1a98ced0fc825ced071 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 24 Sep 2023 15:56:26 -0700 Subject: [PATCH 88/98] Fix used_symbols for the case of a symbol that is in SDFG.symbols but never used --- dace/sdfg/sdfg.py | 27 +++++++------------ .../passes/constant_propagation.py | 2 +- 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index e588feb51c..077be93304 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -62,7 +62,7 @@ def __getitem__(self, key): token = tokens.pop(0) result = result.members[token] return result - + def __setitem__(self, key, val): if isinstance(key, str) and '.' 
in key: raise KeyError('NestedDict does not support setting nested keys') @@ -1335,24 +1335,12 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: defined_syms = set() free_syms = set() - # Exclude data descriptor names, constants, and shapes of global data descriptors - not_strictly_necessary_global_symbols = set() - for name, desc in self.arrays.items(): + # Exclude data descriptor names and constants + for name in self.arrays.keys(): defined_syms.add(name) - if not all_symbols: - used_desc_symbols = desc.used_symbols(all_symbols) - not_strictly_necessary = (desc.used_symbols(all_symbols=True) - used_desc_symbols) - not_strictly_necessary_global_symbols |= set(map(str, not_strictly_necessary)) - defined_syms |= set(self.constants_prop.keys()) - # Start with the set of SDFG free symbols - if all_symbols: - free_syms |= set(self.symbols.keys()) - else: - free_syms |= set(s for s in self.symbols.keys() if s not in not_strictly_necessary_global_symbols) - # Add free state symbols used_before_assignment = set() @@ -1378,6 +1366,11 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: # Remove symbols that were used before they were assigned defined_syms -= used_before_assignment + # Add the set of SDFG symbol parameters + # If all_symbols is False, those symbols would only be added in the case of non-Python tasklets + if all_symbols: + free_syms |= set(self.symbols.keys()) + # Subtract symbols defined in inter-state edges and constants return free_syms - defined_syms @@ -1498,7 +1491,7 @@ def signature_arglist(self, with_types=True, for_call=False, with_arrays=True, a """ arglist = arglist or self.arglist(scalars_only=not with_arrays) return [v.as_arg(name=k, with_types=with_types, for_call=for_call) for k, v in arglist.items()] - + def python_signature_arglist(self, with_types=True, for_call=False, with_arrays=True, arglist=None) -> List[str]: """ Returns a list of arguments necessary to call this SDFG, formatted as a list of Data-Centric Python 
definitions. @@ -1528,7 +1521,7 @@ def signature(self, with_types=True, for_call=False, with_arrays=True, arglist=N :param arglist: An optional cached argument list. """ return ", ".join(self.signature_arglist(with_types, for_call, with_arrays, arglist)) - + def python_signature(self, with_types=True, for_call=False, with_arrays=True, arglist=None) -> str: """ Returns a Data-Centric Python signature of this SDFG, used when generating code. diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index dd2523c005..9cec6d11af 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -118,7 +118,7 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = del edge.data.assignments[sym] # If symbols are never unknown any longer, remove from SDFG - fsyms = sdfg.free_symbols + fsyms = sdfg.used_symbols(all_symbols=False) result = {k: v for k, v in result.items() if k not in fsyms} for sym in result: if sym in sdfg.symbols: From 07b06f31bd75d143b9ce84a234b5d16abbfb52aa Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 24 Sep 2023 15:57:00 -0700 Subject: [PATCH 89/98] Do not inadvertently simplify expressions when computing free symbols --- dace/sdfg/sdfg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 077be93304..238d0b72c7 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -273,7 +273,7 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: rhs_symbols = set() for lhs, rhs in self.assignments.items(): # Always add LHS symbols to the set of candidate free symbols - rhs_symbols |= symbolic.free_symbols_and_functions(rhs) + rhs_symbols |= set(map(str, dace.symbolic.symbols_in_ast(ast.parse(rhs)))) # Add the RHS to the set of candidate defined symbols ONLY if it has not been read yet # This also solves the ordering issue that may arise in cases like 
the 3rd example above if lhs not in cond_symbols and lhs not in rhs_symbols: From 781699925586c2d6b275f7ea9419339737ca68dc Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 24 Sep 2023 17:25:11 -0700 Subject: [PATCH 90/98] Ignore symbol mappings to unused symbols in used_symbols and nested SDFGs --- dace/sdfg/nodes.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 28431deeea..39335dd90d 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -581,6 +581,10 @@ def from_json(json_obj, context=None): return ret def used_symbols(self, all_symbols: bool) -> Set[str]: + free_syms = set().union(*(map(str, pystr_to_symbolic(v).free_symbols) for v in self.location.values())) + + keys_to_use = set(self.symbol_mapping.keys()) + free_syms = set().union(*(map(str, pystr_to_symbolic(v).free_symbols) for v in self.symbol_mapping.values()), *(map(str, @@ -589,8 +593,12 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: # Filter out unused internal symbols from symbol mapping if not all_symbols: internally_used_symbols = self.sdfg.used_symbols(all_symbols=False) - free_syms &= internally_used_symbols - + keys_to_use &= internally_used_symbols + + free_syms |= set().union(*(map(str, + pystr_to_symbolic(v).free_symbols) for k, v in self.symbol_mapping.items() + if k in keys_to_use)) + return free_syms @property @@ -640,7 +648,7 @@ def validate(self, sdfg, state, references: Optional[Set[int]] = None, **context raise NameError('Data descriptor "%s" not found in nested SDFG connectors' % dname) if dname in connectors and desc.transient: raise NameError('"%s" is a connector but its corresponding array is transient' % dname) - + # Validate inout connectors from dace.sdfg import utils # Avoids circular import inout_connectors = self.in_connectors.keys() & self.out_connectors.keys() From 6fc589034955ce744a31d8d1a8fb9a574a2ece95 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 24 Sep 2023 
18:09:51 -0700 Subject: [PATCH 91/98] Revert changes made to the Python frontend (these may belong in a different PR) --- dace/frontend/python/newast.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 9f57759b9d..2010aa6b1e 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -4941,10 +4941,6 @@ def visit_Subscript(self, node: ast.Subscript, inference: bool = False): # If this subscript originates from an external array, create the # subset in the edge going to the connector, as well as a local # reference to the subset - # old_node = node - # if isinstance(node.value, ast.Name): - # true_node = copy.deepcopy(old_node) - # true_node.value.id = true_name if (true_name not in self.sdfg.arrays and isinstance(node.value, ast.Name)): true_node = copy.deepcopy(node) true_node.value.id = true_name @@ -4961,9 +4957,9 @@ def visit_Subscript(self, node: ast.Subscript, inference: bool = False): rng.offset(rng, True) return self.sdfg.arrays[true_name].dtype, rng.size() if is_read: - new_name, new_rng = self._add_read_access(true_name, rng, node) + new_name, new_rng = self._add_read_access(name, rng, node) else: - new_name, new_rng = self._add_write_access(true_name, rng, node) + new_name, new_rng = self._add_write_access(name, rng, node) new_arr = self.sdfg.arrays[new_name] full_rng = subsets.Range.from_array(new_arr) if new_rng.ranges == full_rng.ranges: From 1c42e2466acab3cafcec77b3619eadfd52e6de83 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 24 Sep 2023 22:09:05 -0700 Subject: [PATCH 92/98] Remove line that should have been removed in commit 7816999 --- dace/sdfg/nodes.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 39335dd90d..3c8f38162f 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -585,11 +585,6 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: keys_to_use = 
set(self.symbol_mapping.keys()) - free_syms = set().union(*(map(str, - pystr_to_symbolic(v).free_symbols) for v in self.symbol_mapping.values()), - *(map(str, - pystr_to_symbolic(v).free_symbols) for v in self.location.values())) - # Filter out unused internal symbols from symbol mapping if not all_symbols: internally_used_symbols = self.sdfg.used_symbols(all_symbols=False) From dcac66ecf6859c68da2d1237801d74bb9d0a578d Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 24 Sep 2023 22:21:20 -0700 Subject: [PATCH 93/98] Fix typo --- dace/codegen/targets/framecode.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index dfdbbb392b..b1eb42fe60 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -886,8 +886,8 @@ def generate_code(self, # NOTE: NestedSDFGs frequently contain tautologies in their symbol mapping, e.g., `'i': i`. Do not # redefine the symbols in such cases. 
- if (not is_top_level and isvarName in sdfg.parent_nsdfg_node.symbol_mapping.keys() - and str(sdfg.parent_nsdfg_node.symbol_mapping[isvarName] == isvarName)): + if (not is_top_level and isvarName in sdfg.parent_nsdfg_node.symbol_mapping + and str(sdfg.parent_nsdfg_node.symbol_mapping[isvarName]) == str(isvarName)): continue isvar = data.Scalar(isvarType) callsite_stream.write('%s;\n' % (isvar.as_arg(with_types=True, name=isvarName)), sdfg) From bf9b0231a67551c6794514d84522e9177c2ad68e Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 24 Sep 2023 22:58:15 -0700 Subject: [PATCH 94/98] Better detection of free symbols in C++ tasklets --- dace/sdfg/nodes.py | 10 +++++- dace/sdfg/replace.py | 7 +++++ dace/sdfg/state.py | 35 +++++++++++++-------- dace/symbolic.py | 24 +++++++++++++- dace/transformation/passes/prune_symbols.py | 25 ++++++--------- 5 files changed, 71 insertions(+), 30 deletions(-) diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 3c8f38162f..f60460c50e 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -342,6 +342,8 @@ class Tasklet(CodeNode): 'additional side effects on the system state (e.g., callback). 
' 'Defaults to None, which lets the framework make assumptions based on ' 'the tasklet contents') + ignored_symbols = SetProperty(element_type=str, desc='A set of symbols to ignore when computing ' + 'the symbols used by this tasklet') def __init__(self, label, @@ -355,6 +357,7 @@ def __init__(self, code_exit="", location=None, side_effects=None, + ignored_symbols=None, debuginfo=None): super(Tasklet, self).__init__(label, location, inputs, outputs) @@ -365,6 +368,7 @@ def __init__(self, self.code_init = CodeBlock(code_init, dtypes.Language.CPP) self.code_exit = CodeBlock(code_exit, dtypes.Language.CPP) self.side_effects = side_effects + self.ignored_symbols = ignored_symbols or set() self.debuginfo = debuginfo @property @@ -393,7 +397,11 @@ def validate(self, sdfg, state): @property def free_symbols(self) -> Set[str]: - return self.code.get_free_symbols(self.in_connectors.keys() | self.out_connectors.keys()) + symbols_to_ignore = self.in_connectors.keys() | self.out_connectors.keys() + symbols_to_ignore |= self.ignored_symbols + + return self.code.get_free_symbols(symbols_to_ignore) + def has_side_effects(self, sdfg) -> bool: """ diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index 5e42830a75..4b36fad4fe 100644 --- a/dace/sdfg/replace.py +++ b/dace/sdfg/replace.py @@ -124,6 +124,7 @@ def replace_properties_dict(node: Any, if lang is dtypes.Language.CPP: # Replace in C++ code prefix = '' tokenized = tokenize_cpp.findall(code) + active_replacements = set() for name, new_name in reduced_repl.items(): if name not in tokenized: continue @@ -131,8 +132,14 @@ def replace_properties_dict(node: Any, # Use local variables and shadowing to replace replacement = f'auto {name} = {cppunparse.pyexpr2cpp(new_name)};\n' prefix = replacement + prefix + active_replacements.add(name) if prefix: propval.code = prefix + code + + # Ignore replaced symbols since they no longer exist as reads + if isinstance(node, dace.nodes.Tasklet): + 
node._ignored_symbols.update(active_replacements) + else: warnings.warn('Replacement of %s with %s was not made ' 'for string tasklet code of language %s' % (name, new_name, lang)) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 538f2114b9..8ad0c67bb8 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -27,6 +27,7 @@ if TYPE_CHECKING: import dace.sdfg.scope + def _getdebuginfo(old_dinfo=None) -> dtypes.DebugInfo: """ Returns a DebugInfo object for the position that called this function. @@ -417,7 +418,6 @@ def is_leaf_memlet(self, e): if isinstance(e.dst, nd.EntryNode) and e.dst_conn and e.dst_conn.startswith('IN_'): return False return True - def used_symbols(self, all_symbols: bool) -> Set[str]: """ @@ -438,13 +438,23 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: elif isinstance(n, nd.AccessNode): # Add data descriptor symbols freesyms |= set(map(str, n.desc(sdfg).used_symbols(all_symbols))) - elif (isinstance(n, nd.Tasklet) and n.language == dtypes.Language.Python): - # Consider callbacks defined as symbols as free - for stmt in n.code.code: - for astnode in ast.walk(stmt): - if (isinstance(astnode, ast.Call) and isinstance(astnode.func, ast.Name) - and astnode.func.id in sdfg.symbols): - freesyms.add(astnode.func.id) + elif isinstance(n, nd.Tasklet): + if n.language == dtypes.Language.Python: + # Consider callbacks defined as symbols as free + for stmt in n.code.code: + for astnode in ast.walk(stmt): + if (isinstance(astnode, ast.Call) and isinstance(astnode.func, ast.Name) + and astnode.func.id in sdfg.symbols): + freesyms.add(astnode.func.id) + else: + # Find all string tokens and filter them to sdfg.symbols, while ignoring connectors + codesyms = symbolic.symbols_in_code( + n.code.as_string, + potential_symbols=sdfg.symbols.keys(), + symbols_to_ignore=(n.in_connectors.keys() | n.out_connectors.keys() | n.ignored_symbols), + ) + freesyms |= codesyms + continue if hasattr(n, 'used_symbols'): freesyms |= 
n.used_symbols(all_symbols) @@ -462,7 +472,7 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: # Do not consider SDFG constants as symbols new_symbols.update(set(sdfg.constants.keys())) return freesyms - new_symbols - + @property def free_symbols(self) -> Set[str]: """ @@ -474,7 +484,6 @@ def free_symbols(self) -> Set[str]: """ return self.used_symbols(all_symbols=True) - def defined_symbols(self) -> Dict[str, dt.Data]: """ Returns a dictionary that maps currently-defined symbols in this SDFG @@ -535,8 +544,8 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, # Filter out memlets which go out but the same data is written to the AccessNode by another memlet for out_edge in list(out_edges): for in_edge in list(in_edges): - if (in_edge.data.data == out_edge.data.data and - in_edge.data.dst_subset.covers(out_edge.data.src_subset)): + if (in_edge.data.data == out_edge.data.data + and in_edge.data.dst_subset.covers(out_edge.data.src_subset)): out_edges.remove(out_edge) break @@ -803,7 +812,7 @@ def __init__(self, label=None, sdfg=None, debuginfo=None, location=None): self.nosync = False self.location = location if location is not None else {} self._default_lineinfo = None - + def __deepcopy__(self, memo): cls = self.__class__ result = cls.__new__(cls) diff --git a/dace/symbolic.py b/dace/symbolic.py index 0ab6e3f6ff..87fcc0036c 100644 --- a/dace/symbolic.py +++ b/dace/symbolic.py @@ -14,6 +14,7 @@ from dace import dtypes DEFAULT_SYMBOL_TYPE = dtypes.int32 +_NAME_TOKENS = re.compile(r'[a-zA-Z_][a-zA-Z_0-9]*') # NOTE: Up to (including) version 1.8, sympy.abc._clash is a dictionary of the # form {'N': sympy.abc.N, 'I': sympy.abc.I, 'pi': sympy.abc.pi} @@ -1377,6 +1378,27 @@ def equal(a: SymbolicType, b: SymbolicType, is_length: bool = True) -> Union[boo if is_length: for arg in args: facts += [sympy.Q.integer(arg), sympy.Q.positive(arg)] - + with sympy.assuming(*facts): return sympy.ask(sympy.Q.is_true(sympy.Eq(*args))) + + +def 
symbols_in_code(code: str, potential_symbols: Set[str] = None, + symbols_to_ignore: Set[str] = None) -> Set[str]: + """ + Tokenizes a code string for symbols and returns a set thereof. + + :param code: The code to tokenize. + :param potential_symbols: If not None, filters symbols to this given set. + :param symbols_to_ignore: If not None, filters out symbols from this set. + """ + if not code: + return set() + if potential_symbols is not None and len(potential_symbols) == 0: + # Don't bother tokenizing for an empty set of potential symbols + return set() + + tokens = set(re.findall(_NAME_TOKENS, code)) + if potential_symbols is not None: + tokens &= potential_symbols + return tokens - symbols_to_ignore diff --git a/dace/transformation/passes/prune_symbols.py b/dace/transformation/passes/prune_symbols.py index 94fcbdbc58..cf55f7a9b2 100644 --- a/dace/transformation/passes/prune_symbols.py +++ b/dace/transformation/passes/prune_symbols.py @@ -1,16 +1,13 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. 
import itertools -import re from dataclasses import dataclass from typing import Optional, Set, Tuple -from dace import SDFG, dtypes, properties +from dace import SDFG, dtypes, properties, symbolic from dace.sdfg import nodes from dace.transformation import pass_pipeline as ppl -_NAME_TOKENS = re.compile(r'[a-zA-Z_][a-zA-Z_0-9]*') - @dataclass(unsafe_hash=True) @properties.make_properties @@ -81,7 +78,7 @@ def used_symbols(self, sdfg: SDFG) -> Set[str]: # Add symbols in global/init/exit code for code in itertools.chain(sdfg.global_code.values(), sdfg.init_code.values(), sdfg.exit_code.values()): - result |= _symbols_in_code(code.as_string) + result |= symbolic.symbols_in_code(code.as_string) for desc in sdfg.arrays.values(): result |= set(map(str, desc.free_symbols)) @@ -94,21 +91,19 @@ def used_symbols(self, sdfg: SDFG) -> Set[str]: for node in state.nodes(): if isinstance(node, nodes.Tasklet): if node.code.language != dtypes.Language.Python: - result |= _symbols_in_code(node.code.as_string) + result |= symbolic.symbols_in_code(node.code.as_string, sdfg.symbols.keys(), + node.ignored_symbols) if node.code_global.language != dtypes.Language.Python: - result |= _symbols_in_code(node.code_global.as_string) + result |= symbolic.symbols_in_code(node.code_global.as_string, sdfg.symbols.keys(), + node.ignored_symbols) if node.code_init.language != dtypes.Language.Python: - result |= _symbols_in_code(node.code_init.as_string) + result |= symbolic.symbols_in_code(node.code_init.as_string, sdfg.symbols.keys(), + node.ignored_symbols) if node.code_exit.language != dtypes.Language.Python: - result |= _symbols_in_code(node.code_exit.as_string) - + result |= symbolic.symbols_in_code(node.code_exit.as_string, sdfg.symbols.keys(), + node.ignored_symbols) for e in sdfg.edges(): result |= e.data.free_symbols return result - -def _symbols_in_code(code: str) -> Set[str]: - if not code: - return set() - return set(re.findall(_NAME_TOKENS, code)) From 
616207490cb5d990d5c00096d44fd891c8401cb5 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 24 Sep 2023 23:23:22 -0700 Subject: [PATCH 95/98] Minor fix --- dace/symbolic.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dace/symbolic.py b/dace/symbolic.py index 87fcc0036c..e9249218f9 100644 --- a/dace/symbolic.py +++ b/dace/symbolic.py @@ -1401,4 +1401,6 @@ def symbols_in_code(code: str, potential_symbols: Set[str] = None, tokens = set(re.findall(_NAME_TOKENS, code)) if potential_symbols is not None: tokens &= potential_symbols + if symbols_to_ignore is None: + return tokens return tokens - symbols_to_ignore From d2c4370e7b6554140730938f8fca6dee73da56b0 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 25 Sep 2023 01:11:48 -0700 Subject: [PATCH 96/98] Consider free symbols in SDFG init and exit code, fix None vs. empty logic bug --- dace/sdfg/sdfg.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 238d0b72c7..8388cce250 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -1341,6 +1341,12 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: defined_syms |= set(self.constants_prop.keys()) + # Add used symbols from init and exit code + for code in self.init_code.values(): + free_syms |= symbolic.symbols_in_code(code.as_string, self.symbols.keys()) + for code in self.exit_code.values(): + free_syms |= symbolic.symbols_in_code(code.as_string, self.symbols.keys()) + # Add free state symbols used_before_assignment = set() @@ -1472,7 +1478,7 @@ def init_signature(self, for_call=False, free_symbols=None) -> str: :param for_call: If True, returns arguments that can be used when calling the SDFG. 
""" # Get global free symbols scalar arguments - free_symbols = free_symbols or self.free_symbols + free_symbols = free_symbols if free_symbols is not None else self.used_symbols(all_symbols=False) return ", ".join( dt.Scalar(self.symbols[k]).as_arg(name=k, with_types=not for_call, for_call=for_call) for k in sorted(free_symbols) if not k.startswith('__dace')) From 2943429b6b6c33f48e22c70186cf7b7622ab2771 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 25 Sep 2023 01:34:37 -0700 Subject: [PATCH 97/98] Handle struct memlets in normalization, structure views as views --- dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index b33568e864..b5a80597c0 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -68,6 +68,10 @@ def dealias_sdfg(sdfg: SDFG): parent_arr.lifetime, parent_arr.alignment, parent_arr.debuginfo, parent_arr.total_size, parent_arr.start_offset, parent_arr.optional, parent_arr.pool) + elif isinstance(parent_arr, data.StructureView): + parent_arr = data.Structure(parent_arr.members, parent_arr.name, parent_arr.transient, + parent_arr.storage, parent_arr.location, parent_arr.lifetime, + parent_arr.debuginfo) child_names = inv_replacements[parent_name] for name in child_names: child_arr = copy.deepcopy(parent_arr) @@ -158,6 +162,8 @@ def normalize_memlet(sdfg: SDFG, state: SDFGState, original: gr.MultiConnectorEd copy.deepcopy(original.data), original.key) edge.data.try_initialize(sdfg, state, edge) + if '.' in edge.data.data and edge.data.data.startswith(data + '.'): + return edge.data if edge.data.data == data: return edge.data @@ -379,7 +385,7 @@ def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[ # 1. 
Check for views if isinstance(e.src, dace.nodes.AccessNode): desc = e.src.desc(sdfg) - if isinstance(desc, dace.data.View): + if isinstance(desc, (dace.data.View, dace.data.StructureView)): vedge = sdutil.get_view_edge(state, e.src) if e is vedge: viewed_node = sdutil.get_view_node(state, e.src) @@ -389,7 +395,7 @@ def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[ continue if isinstance(e.dst, dace.nodes.AccessNode): desc = e.dst.desc(sdfg) - if isinstance(desc, dace.data.View): + if isinstance(desc, (dace.data.View, dace.data.StructureView)): vedge = sdutil.get_view_edge(state, e.dst) if e is vedge: viewed_node = sdutil.get_view_node(state, e.dst) From 2935f55791ee53987708fdc70fce4b40b0c556b1 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 26 Sep 2023 08:58:21 -0700 Subject: [PATCH 98/98] Apply review comments --- dace/sdfg/analysis/schedule_tree/passes.py | 2 +- .../analysis/schedule_tree/sdfg_to_tree.py | 55 +++++++++---------- dace/sdfg/analysis/schedule_tree/treenodes.py | 15 +---- dace/sdfg/memlet_utils.py | 10 +--- dace/sdfg/nodes.py | 4 +- dace/sdfg/sdfg.py | 7 ++- tests/schedule_tree/naming_test.py | 17 ++++-- tests/sdfg/memlet_utils_test.py | 2 +- 8 files changed, 54 insertions(+), 58 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/passes.py b/dace/sdfg/analysis/schedule_tree/passes.py index feda258e20..cc33245875 100644 --- a/dace/sdfg/analysis/schedule_tree/passes.py +++ b/dace/sdfg/analysis/schedule_tree/passes.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ Assortment of passes for schedule trees. 
""" diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index b5a80597c0..917f748cb8 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. from collections import defaultdict import copy from typing import Dict, List, Set @@ -28,6 +28,15 @@ def dealias_sdfg(sdfg: SDFG): + """ + Renames all data containers in an SDFG tree (i.e., nested SDFGs) to use the same data descriptors + as the top-level SDFG. This function takes care of offsetting memlets and internal + uses of arrays such that there is one naming system, and no aliasing of managed memory. + + This function operates in-place. + + :param sdfg: The SDFG to operate on. + """ for nsdfg in sdfg.all_sdfgs_recursive(): if not nsdfg.parent: @@ -243,7 +252,7 @@ def replace_memlets(sdfg: SDFG, input_mapping: Dict[str, Memlet], output_mapping repl_dict[s] = str(input_mapping[s]) # Manual replacement with strings - # TODO(later): Would be MUCH better to use e.data.replace_dict(repl_dict, replace_keys=False) + # TODO(later): Would be MUCH better to use MemletReplacer / e.data.replace_dict(repl_dict, replace_keys=False) for find, replace in repl_dict.items(): for k, v in e.data.assignments.items(): if find in v: @@ -288,8 +297,8 @@ def remove_name_collisions(sdfg: SDFG): name = new_name identifiers_seen.add(name) - # Rename duplicate symbols - for name in nsdfg.get_all_symbols(): + # Rename duplicate top-level symbols + for name in nsdfg.get_all_toplevel_symbols(): # Will already be renamed during conversion if parent_node is not None and name in parent_node.symbol_mapping: continue @@ -487,10 +496,7 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: # Create scope node and add to stack scopes.append(result) 
subnodes = [] - result.append(NODE_TO_SCOPE_TYPE[type(node)](node=node, - sdfg=state.parent, - top_level=False, - children=subnodes)) + result.append(NODE_TO_SCOPE_TYPE[type(node)](node=node, children=subnodes)) result = subnodes elif isinstance(node, dace.nodes.ExitNode): result = scopes.pop() @@ -540,7 +546,6 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: in_memlets = {e.dst_conn: e.data for e in state.in_edges(node) if e.dst_conn} out_memlets = {e.src_conn: e.data for e in state.out_edges(node) if e.src_conn} result.append(tn.TaskletNode(node=node, in_memlets=in_memlets, out_memlets=out_memlets)) - # result.append(tn.TaskletNode(sdfg=sdfg, node=node, in_memlets=in_memlets, out_memlets=out_memlets)) elif isinstance(node, dace.nodes.LibraryNode): # NOTE: LibraryNodes do not necessarily have connectors if node.in_connectors: @@ -552,7 +557,6 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: else: out_memlets = set([e.data for e in state.out_edges(node)]) result.append(tn.LibraryCall(node=node, in_memlets=in_memlets, out_memlets=out_memlets)) - # result.append(tn.LibraryCall(sdfg=sdfg, node=node, in_memlets=in_memlets, out_memlets=out_memlets)) elif isinstance(node, dace.nodes.AccessNode): # If one of the neighboring edges has a schedule tree node attached to it, use that # (except for views, which were generated above) @@ -650,7 +654,7 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche subnodes.extend(totree(n, node)) if not node.sequential: # Nest in general block - result = [tn.GBlock(sdfg, top_level=False, children=subnodes)] + result = [tn.GBlock(children=subnodes)] else: # Use the sub-nodes directly result = subnodes @@ -683,39 +687,32 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche if sdfg.out_degree(node.state) == 1 and parent.sequential: # Conditional state in sequential block! 
Add "if not condition goto exit" result.append( - tn.StateIfScope(sdfg=sdfg, - top_level=False, - condition=CodeBlock(negate_expr(e.data.condition)), + tn.StateIfScope(condition=CodeBlock(negate_expr(e.data.condition)), children=[tn.GotoNode(target=None)])) result.extend(edge_body) else: # Add "if condition" with the body above - result.append( - tn.StateIfScope(sdfg=sdfg, - top_level=False, - condition=e.data.condition, - children=edge_body)) + result.append(tn.StateIfScope(condition=e.data.condition, children=edge_body)) else: result.extend(edge_body) elif isinstance(node, cf.ForScope): - result.append(tn.ForScope(sdfg=sdfg, top_level=False, header=node, children=totree(node.body))) + result.append(tn.ForScope(header=node, children=totree(node.body))) elif isinstance(node, cf.IfScope): - result.append(tn.IfScope(sdfg=sdfg, top_level=False, condition=node.condition, children=totree(node.body))) + result.append(tn.IfScope(condition=node.condition, children=totree(node.body))) if node.orelse is not None: - result.append(tn.ElseScope(sdfg=sdfg, top_level=False, children=totree(node.orelse))) + result.append(tn.ElseScope(children=totree(node.orelse))) elif isinstance(node, cf.IfElseChain): # Add "if" for the first condition, "elif"s for the rest - result.append( - tn.IfScope(sdfg=sdfg, top_level=False, condition=node.body[0][0], children=totree(node.body[0][1]))) + result.append(tn.IfScope(condition=node.body[0][0], children=totree(node.body[0][1]))) for cond, body in node.body[1:]: - result.append(tn.ElifScope(sdfg=sdfg, top_level=False, condition=cond, children=totree(body))) + result.append(tn.ElifScope(condition=cond, children=totree(body))) # "else goto exit" - result.append(tn.ElseScope(sdfg=sdfg, top_level=False, children=[tn.GotoNode(target=None)])) + result.append(tn.ElseScope(children=[tn.GotoNode(target=None)])) elif isinstance(node, cf.WhileScope): - result.append(tn.WhileScope(sdfg=sdfg, top_level=False, header=node, children=totree(node.body))) + 
result.append(tn.WhileScope(header=node, children=totree(node.body))) elif isinstance(node, cf.DoWhileScope): - result.append(tn.DoWhileScope(sdfg=sdfg, top_level=False, header=node, children=totree(node.body))) + result.append(tn.DoWhileScope(header=node, children=totree(node.body))) else: # e.g., "SwitchCaseScope" raise tn.UnsupportedScopeException(type(node).__name__) @@ -726,7 +723,7 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche return result # Recursive traversal of the control flow tree - result = tn.ScheduleTreeScope(sdfg=sdfg, top_level=True, children=totree(cfg)) + result = tn.ScheduleTreeScope(children=totree(cfg)) # Clean up tree stpasses.remove_unused_and_duplicate_labels(result) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 52b6db0361..99918cd2a4 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -1,16 +1,13 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
-import copy from dataclasses import dataclass, field from dace import nodes, data, subsets from dace.codegen import control_flow as cf -from dace.dtypes import TYPECLASS_TO_STRING from dace.properties import CodeBlock -from dace.sdfg import SDFG, InterstateEdge +from dace.sdfg import InterstateEdge from dace.sdfg.state import SDFGState from dace.symbolic import symbol from dace.memlet import Memlet -from functools import reduce -from typing import Dict, Iterator, List, Optional, Set, Tuple, Union +from typing import Dict, Iterator, List, Optional, Set, Union INDENTATION = ' ' @@ -21,8 +18,6 @@ class UnsupportedScopeException(Exception): @dataclass class ScheduleTreeNode: - # sdfg: SDFG - # sdfg: Optional[SDFG] = field(default=None, init=False) parent: Optional['ScheduleTreeScope'] = field(default=None, init=False) def as_string(self, indent: int = 0): @@ -37,18 +32,12 @@ def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: @dataclass class ScheduleTreeScope(ScheduleTreeNode): - sdfg: SDFG - top_level: bool children: List['ScheduleTreeNode'] containers: Optional[Dict[str, data.Data]] = field(default_factory=dict, init=False) symbols: Optional[Dict[str, symbol]] = field(default_factory=dict, init=False) def __init__(self, - sdfg: Optional[SDFG] = None, - top_level: Optional[bool] = False, children: Optional[List['ScheduleTreeNode']] = None): - self.sdfg = sdfg - self.top_level = top_level self.children = children or [] if self.children: for child in children: diff --git a/dace/sdfg/memlet_utils.py b/dace/sdfg/memlet_utils.py index c0c5201e50..59a2c178d2 100644 --- a/dace/sdfg/memlet_utils.py +++ b/dace/sdfg/memlet_utils.py @@ -1,12 +1,8 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
import ast -from dace.frontend.python import astutils, memlet_parser -from dace.sdfg import SDFG, SDFGState, nodes -from dace.sdfg import graph as gr -from dace.sdfg import utils as sdutil -from dace.properties import CodeBlock -from dace import data, subsets, Memlet +from dace.frontend.python import memlet_parser +from dace import data, Memlet from typing import Callable, Dict, Optional, Set, Union diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index f60460c50e..32369a19a3 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -343,7 +343,9 @@ class Tasklet(CodeNode): 'Defaults to None, which lets the framework make assumptions based on ' 'the tasklet contents') ignored_symbols = SetProperty(element_type=str, desc='A set of symbols to ignore when computing ' - 'the symbols used by this tasklet') + 'the symbols used by this tasklet. Used to skip certain symbols in non-Python ' + 'tasklets, where only string analysis is possible; and to skip globals in Python ' + 'tasklets that should not be given as parameters to the SDFG.') def __init__(self, label, diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 8388cce250..a7b5d90b2b 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -1392,9 +1392,12 @@ def free_symbols(self) -> Set[str]: """ return self.used_symbols(all_symbols=True) - def get_all_symbols(self) -> Set[str]: + def get_all_toplevel_symbols(self) -> Set[str]: """ - Returns a set of all symbol names that are used by the SDFG. + Returns a set of all symbol names that are used by the SDFG's state machine. + This includes all symbols in the descriptor repository and interstate edges, + whether free or defined. Used to identify duplicates when, e.g., inlining or + dealiasing a set of nested SDFGs. 
""" # Exclude constants and data descriptor names exclude = set(self.arrays.keys()) | set(self.constants_prop.keys()) diff --git a/tests/schedule_tree/naming_test.py b/tests/schedule_tree/naming_test.py index 517716e1bc..0811682870 100644 --- a/tests/schedule_tree/naming_test.py +++ b/tests/schedule_tree/naming_test.py @@ -169,10 +169,7 @@ def test_edgecase_symbol_mapping(): assert stree.children[2].value.as_string in ('k', '(N + 1)') -def test_clash_iteration_symbols(): - sdfg = _nested_irreducible_loops() - - stree = as_schedule_tree(sdfg) +def _check_for_name_clashes(stree: tn.ScheduleTreeNode): def _traverse(node: tn.ScheduleTreeScope, scopes: List[str]): for child in node.children: @@ -181,12 +178,24 @@ def _traverse(node: tn.ScheduleTreeScope, scopes: List[str]): if itervar in scopes: raise NameError('Nested scope redefines iteration variable') _traverse(child, scopes + [itervar]) + elif isinstance(child, tn.MapScope): + itervars = child.node.map.params + if any(itervar in scopes for itervar in itervars): + raise NameError('Nested scope redefines iteration variable') + _traverse(child, scopes + itervars) elif isinstance(child, tn.ScheduleTreeScope): _traverse(child, scopes) _traverse(stree, []) +def test_clash_iteration_symbols(): + sdfg = _nested_irreducible_loops() + + stree = as_schedule_tree(sdfg) + _check_for_name_clashes(stree) + + if __name__ == '__main__': test_clash_states() test_clash_symbol_mapping(False) diff --git a/tests/sdfg/memlet_utils_test.py b/tests/sdfg/memlet_utils_test.py index 752b9ef55d..467838fc56 100644 --- a/tests/sdfg/memlet_utils_test.py +++ b/tests/sdfg/memlet_utils_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. import dace import numpy as np