Skip to content

Commit

Permalink
Fix stale access node store
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Dec 16, 2024
1 parent 93f2387 commit 9641e23
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions dace/transformation/passes/array_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,6 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[S
removed_nodes = self.merge_access_nodes(state, access_nodes, lambda n: state.in_degree(n) == 0)
removed_nodes |= self.merge_access_nodes(state, access_nodes, lambda n: state.out_degree(n) == 0)

# Update access nodes with merged nodes
access_nodes = {k: [n for n in v if n not in removed_nodes] for k, v in access_nodes.items()}

# Remove redundant views
removed_nodes |= self.remove_redundant_views(sdfg, state, access_nodes)

Expand Down Expand Up @@ -108,7 +105,8 @@ def merge_access_nodes(self, state: SDFGState, access_nodes: Dict[str, List[node
Merges access nodes that follow the same conditions together to the first access node.
"""
removed_nodes: Set[nodes.AccessNode] = set()
for nodeset in access_nodes.values():
for data_container in access_nodes.keys():
nodeset = access_nodes[data_container]
if len(nodeset) > 1:
# Merge all other access nodes to the first one that fits the condition, if one exists.
first_node = None
Expand Down Expand Up @@ -159,6 +157,7 @@ def merge_access_nodes(self, state: SDFGState, access_nodes: Dict[str, List[node
# Remove merged node and associated edges
state.remove_node(node)
removed_nodes.add(node)
access_nodes[data_container] = [n for n in nodeset if n not in removed_nodes]
return removed_nodes

def remove_redundant_views(self, sdfg: SDFG, state: SDFGState, access_nodes: Dict[str, List[nodes.AccessNode]]):
Expand Down

0 comments on commit 9641e23

Please sign in to comment.