diff --git a/dace/transformation/passes/array_elimination.py b/dace/transformation/passes/array_elimination.py index e8cb44b46d..803c81b21e 100644 --- a/dace/transformation/passes/array_elimination.py +++ b/dace/transformation/passes/array_elimination.py @@ -69,9 +69,6 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[S removed_nodes = self.merge_access_nodes(state, access_nodes, lambda n: state.in_degree(n) == 0) removed_nodes |= self.merge_access_nodes(state, access_nodes, lambda n: state.out_degree(n) == 0) - # Update access nodes with merged nodes - access_nodes = {k: [n for n in v if n not in removed_nodes] for k, v in access_nodes.items()} - # Remove redundant views removed_nodes |= self.remove_redundant_views(sdfg, state, access_nodes) @@ -108,7 +105,8 @@ def merge_access_nodes(self, state: SDFGState, access_nodes: Dict[str, List[node Merges access nodes that follow the same conditions together to the first access node. """ removed_nodes: Set[nodes.AccessNode] = set() - for nodeset in access_nodes.values(): + for data_container in access_nodes.keys(): + nodeset = access_nodes[data_container] if len(nodeset) > 1: # Merge all other access nodes to the first one that fits the condition, if one exists. first_node = None @@ -159,6 +157,7 @@ def merge_access_nodes(self, state: SDFGState, access_nodes: Dict[str, List[node # Remove merged node and associated edges state.remove_node(node) removed_nodes.add(node) + access_nodes[data_container] = [n for n in nodeset if n not in removed_nodes] return removed_nodes def remove_redundant_views(self, sdfg: SDFG, state: SDFGState, access_nodes: Dict[str, List[nodes.AccessNode]]):