Skip to content

Commit

Permalink
Fix edge case
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun committed Sep 20, 2023
1 parent 9d27e72 commit aac7013
Showing 1 changed file with 50 additions and 5 deletions.
55 changes: 50 additions & 5 deletions dace/transformation/passes/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,62 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Set[SDFGSta
# The implementation below is faster
# tc: nx.DiGraph = nx.transitive_closure(sdfg.nx)

for n, v in nx.all_pairs_shortest_path_length(sdfg.nx):
result[n] = set(t for t, l in v.items() if l > 0)
# Add self-edges
if n in sdfg.successors(n):
result[n].add(n)
for n, v in reachable_nodes(sdfg.nx):
result[n] = set(v)

reachable[sdfg.sdfg_id] = result

return reachable


def _single_shortest_path_length_no_self(adj, source):
"""Yields (node, level) in a breadth first search, without the first level
unless a self-edge exists.
Adapted from Shortest Path Length helper function in NetworkX.
Parameters
----------
adj : dict
Adjacency dict or view
firstlevel : dict
starting nodes, e.g. {source: 1} or {target: 1}
cutoff : int or float
level at which we stop the process
"""
firstlevel = {source: 1}

seen = {} # level (number of hops) when seen in BFS
level = 0 # the current level
nextlevel = set(firstlevel) # set of nodes to check at next level
n = len(adj)
while nextlevel:
thislevel = nextlevel # advance to next level
nextlevel = set() # and start a new set (fringe)
found = []
for v in thislevel:
if v not in seen:
if level == 0 and v is source: # Skip 0-length path to self
found.append(v)
continue
seen[v] = level # set the level of vertex v
found.append(v)
yield (v, level)
if len(seen) == n:
return
for v in found:
nextlevel.update(adj[v])
level += 1
del seen


def reachable_nodes(G):
"""Computes the reachable nodes in G."""
adj = G.adj
for n in G:
yield (n, dict(_single_shortest_path_length_no_self(adj, n)))


@properties.make_properties
class SymbolAccessSets(ppl.Pass):
"""
Expand Down

0 comments on commit aac7013

Please sign in to comment.