diff --git a/dace/transformation/passes/analysis.py b/dace/transformation/passes/analysis.py index b59bfee5d1..86e1cde062 100644 --- a/dace/transformation/passes/analysis.py +++ b/dace/transformation/passes/analysis.py @@ -42,17 +42,62 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Set[SDFGSta # The implementation below is faster # tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) - for n, v in nx.all_pairs_shortest_path_length(sdfg.nx): - result[n] = set(t for t, l in v.items() if l > 0) - # Add self-edges - if n in sdfg.successors(n): - result[n].add(n) + for n, v in reachable_nodes(sdfg.nx): + result[n] = set(v) reachable[sdfg.sdfg_id] = result return reachable +def _single_shortest_path_length_no_self(adj, source): + """Yields (node, level) in a breadth first search, without the first level + unless a self-edge exists. + + Adapted from Shortest Path Length helper function in NetworkX. + + Parameters + ---------- + adj : dict + Adjacency dict or view + firstlevel : dict + starting nodes, e.g. {source: 1} or {target: 1} + cutoff : int or float + level at which we stop the process + """ + firstlevel = {source: 1} + + seen = {} # level (number of hops) when seen in BFS + level = 0 # the current level + nextlevel = set(firstlevel) # set of nodes to check at next level + n = len(adj) + while nextlevel: + thislevel = nextlevel # advance to next level + nextlevel = set() # and start a new set (fringe) + found = [] + for v in thislevel: + if v not in seen: + if level == 0 and v is source: # Skip 0-length path to self + found.append(v) + continue + seen[v] = level # set the level of vertex v + found.append(v) + yield (v, level) + if len(seen) == n: + return + for v in found: + nextlevel.update(adj[v]) + level += 1 + del seen + + +def reachable_nodes(G): + """Computes the reachable nodes in G.""" + adj = G.adj + for n in G: + yield (n, dict(_single_shortest_path_length_no_self(adj, n))) + + @properties.make_properties class SymbolAccessSets(ppl.Pass): """