Skip to content

Commit

Permalink
add find_cycles and use it
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Dec 12, 2022
1 parent e9d1f57 commit 27255a5
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 20 deletions.
61 changes: 43 additions & 18 deletions pytools/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
.. autoexception:: CycleError
.. autofunction:: compute_topological_order
.. autofunction:: compute_transitive_closure
.. autofunction:: find_cycles
.. autofunction:: contains_cycle
.. autofunction:: compute_induced_subgraph
.. autofunction:: validate_graph
Expand Down Expand Up @@ -240,6 +241,42 @@ def __init__(self, node: NodeT) -> None:
self.node = node


def find_cycles(graph: GraphT) -> List[List[NodeT]]:
"""
Find all cycles in *graph* using DFS.
:returns: A :class:`list` in which each element represents another :class:`list`
of nodes that form a cycle.
"""
def dfs(node: NodeT, path: List[NodeT]) -> List[NodeT]:
# Cycle detected
if visited[node] == 1:
return path

# Visit this node, explore its children
visited[node] = 1
path.append(node)
for child in graph[node]:
if visited[child] != 2 and dfs(child, path):
return path

# Done visiting node
visited[node] = 2
return []

visited = {node: 0 for node in graph.keys()}

res = []

for node in graph:
if not visited[node]:
cycle = dfs(node, [])
if cycle:
res.append(cycle)

return res


class HeapEntry:
"""
Helper class to compare associated keys while comparing the elements in
Expand All @@ -257,8 +294,8 @@ def __lt__(self, other: "HeapEntry") -> bool:


def compute_topological_order(graph: GraphT,
key: Optional[Callable[[T], Any]] = None,
verbose_cycle: bool = True) -> List[T]:
key: Optional[Callable[[NodeT], Any]] = None,
verbose_cycle: bool = True) -> List[NodeT]:
"""Compute a topological order of nodes in a directed graph.
:arg key: A custom key function may be supplied to determine the order in
Expand Down Expand Up @@ -323,24 +360,12 @@ def compute_topological_order(graph: GraphT,
raise CycleError(None)

try:
validate_graph(graph)
except ValueError:
# Graph is invalid, we can't compute SCCs or return a meaningful node
# that is part of a cycle
cycles = find_cycles(graph)
except KeyError:
# Graph is invalid
raise CycleError(None)

sccs = compute_sccs(graph)
cycles = [scc for scc in sccs if len(scc) > 1]

if cycles:
# Cycles that are not self-loops
node = cycles[0][0]
else:
# Self-loop SCCs also have a length of 1
node = next(iter(n for n, num_preds in
nodes_to_num_predecessors.items() if num_preds != 0))

raise CycleError(node)
raise CycleError(cycles[0][0])

return order

Expand Down
13 changes: 11 additions & 2 deletions test/test_graph_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,24 +395,33 @@ def test_is_connected():
assert is_connected({})


def test_cycle_detection():
from pytools.graph import compute_topological_order, CycleError
def test_find_cycles():
from pytools.graph import compute_topological_order, CycleError, find_cycles

# Non-Self Loop
graph = {1: {}, 5: {1, 8}, 8: {5}}
assert find_cycles(graph) == [[5, 8]]
with pytest.raises(CycleError, match="5|8"):
compute_topological_order(graph)

# Self-Loop
graph = {1: {1}, 5: {8}, 8: {}}
assert find_cycles(graph) == [[1]]
with pytest.raises(CycleError, match="1"):
compute_topological_order(graph)

# Invalid graph with loop
graph = {1: {42}, 5: {8}, 8: {5}}
# Can't run find_cycles on this graph since it is invalid
with pytest.raises(CycleError, match="None"):
compute_topological_order(graph)

# Multiple loops
graph = {1: {1}, 5: {8}, 8: {5}}
assert find_cycles(graph) == [[1], [5, 8]]
with pytest.raises(CycleError, match="1"):
compute_topological_order(graph)


if __name__ == "__main__":
if len(sys.argv) > 1:
Expand Down

0 comments on commit 27255a5

Please sign in to comment.