Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better cycle detection #167

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 66 additions & 10 deletions pytools/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
.. autoexception:: CycleError
.. autofunction:: compute_topological_order
.. autofunction:: compute_transitive_closure
.. autofunction:: find_cycles
.. autofunction:: contains_cycle
.. autofunction:: compute_induced_subgraph
.. autofunction:: as_graphviz_dot
Expand All @@ -68,6 +69,8 @@
Mapping, MutableSet, Optional, Set, Tuple, TypeVar)


from enum import Enum

try:
from typing import TypeAlias
except ImportError:
Expand Down Expand Up @@ -242,6 +245,52 @@ def __init__(self, node: NodeT) -> None:
self.node = node


class _NodeState(Enum):
    """Visitation state of a node during a depth-first search.

    The members follow the white/grey/black colouring commonly used to
    describe DFS.
    """

    WHITE = 0  # Not visited yet
    GREY = 1  # Currently visiting (the node's successors are being explored)
    BLACK = 2  # Done visiting
Comment on lines +249 to +251
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use descriptive names for the node state?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you prefer, I can rename these, but I thought white/grey/black were standard labels in DFS (see e.g. http://www.cs.cmu.edu/afs/cs/academic/class/15750-s17/ScribeNotes/lecture9.pdf)



def find_cycles(graph: GraphT[NodeT],
                all_cycles: bool = True) -> List[List[NodeT]]:
    """
    Find cycles in *graph* using an iterative depth-first search (DFS).

    Every edge that leads back to a node on the branch currently being
    explored closes a cycle; one cycle is reported for each such edge. This
    guarantees that at least one cycle is reported whenever *graph* is
    cyclic, but it is not an enumeration of every elementary cycle of
    *graph*.

    :arg all_cycles: If *False*, only return the first cycle found.

    :returns: A :class:`list` in which each element is a :class:`list` of the
        nodes that form one cycle, in the order in which they are traversed.
        The node closing the cycle is not repeated at the end, so a self-loop
        on node *n* is reported as ``[n]``. The result is empty if *graph*
        contains no cycle.

    :raises KeyError: if a node has a successor that is not itself a key of
        *graph*, i.e. if *graph* is invalid.
    """
    state = {node: _NodeState.WHITE for node in graph}
    cycles: List[List[NodeT]] = []

    for root in graph:
        if state[root] is not _NodeState.WHITE:
            continue

        # *branch* holds the nodes from *root* down to the node currently
        # being explored, *successors* the matching iterators over their
        # not-yet-examined children. An explicit stack is used instead of
        # recursion so that deep graphs cannot exhaust the recursion limit.
        state[root] = _NodeState.GREY
        branch = [root]
        successors = [iter(graph[root])]
        # Position of each node within *branch*. Only consulted for GREY
        # nodes, which are exactly the nodes currently on *branch*.
        depth_of = {root: 0}

        while branch:
            try:
                child = next(successors[-1])
            except StopIteration:
                # All children examined: the node is done.
                successors.pop()
                state[branch.pop()] = _NodeState.BLACK
                continue

            # Deliberately an indexing operation: raises KeyError for
            # successors that are not nodes of the graph.
            child_state = state[child]

            if child_state is _NodeState.GREY:
                # *child* is an ancestor of the current node (or the node
                # itself), so the tail of the branch starting at *child* is
                # a cycle. It is only materialized here, once a cycle is
                # known to exist.
                cycles.append(branch[depth_of[child]:])
                if not all_cycles:
                    return cycles
            elif child_state is _NodeState.WHITE:
                state[child] = _NodeState.GREY
                depth_of[child] = len(branch)
                branch.append(child)
                successors.append(iter(graph[child]))
            # BLACK children are fully explored and cannot lead back to the
            # current branch, so they are skipped.

    return cycles


class HeapEntry:
"""
Helper class to compare associated keys while comparing the elements in
Expand All @@ -259,14 +308,17 @@ def __lt__(self, other: "HeapEntry") -> bool:


def compute_topological_order(graph: GraphT[NodeT],
key: Optional[Callable[[NodeT], Any]] = None) \
-> List[NodeT]:
key: Optional[Callable[[NodeT], Any]] = None,
verbose_cycle: bool = True) -> List[NodeT]:
"""Compute a topological order of nodes in a directed graph.

:arg key: A custom key function may be supplied to determine the order in
break-even cases. Expects a function of one argument that is used to
extract a comparison key from each node of the *graph*.

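:arg verbose_cycle: Verbose reporting in case *graph* contains a cycle, i.e.
raise a :class:`CycleError` whose *node* attribute is a node that is part of
a cycle. If *False*, or if *graph* is invalid (a successor is not itself a
node of *graph*), the *node* attribute of the :class:`CycleError` is *None*.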

:returns: A :class:`list` representing a valid topological ordering of the
nodes in the directed graph.

Expand Down Expand Up @@ -318,9 +370,17 @@ def compute_topological_order(graph: GraphT[NodeT],
heappush(heap, HeapEntry(child, keyfunc(child)))

if len(order) != total_num_nodes:
# any node which has a predecessor left is a part of a cycle
raise CycleError(next(iter(n for n, num_preds in
nodes_to_num_predecessors.items() if num_preds != 0)))
# There is a cycle in the graph
inducer marked this conversation as resolved.
Show resolved Hide resolved
if not verbose_cycle:
raise CycleError(None)

try:
cycles: List[List[NodeT]] = find_cycles(graph)
except KeyError:
# Graph is invalid
raise CycleError(None)
else:
raise CycleError(cycles[0][0])
Comment on lines +381 to +383
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add to the documentation of CycleError what the value might mean.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, the current doc of CycleError has :attr node: Node in a directed graph that is part of a cycle. - I'm not sure what else to add there.


return order

Expand Down Expand Up @@ -373,11 +433,7 @@ def contains_cycle(graph: GraphT[NodeT]) -> bool:
.. versionadded:: 2020.2
"""

try:
compute_topological_order(graph)
return False
except CycleError:
return True
return bool(find_cycles(graph, all_cycles=False))

# }}}

Expand Down
32 changes: 32 additions & 0 deletions test/test_graph_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,38 @@ def test_is_connected():
assert is_connected({})


def test_find_cycles():
    """Check :func:`pytools.graph.find_cycles` and the cycle reporting of
    :func:`pytools.graph.compute_topological_order` on small graphs."""
    from pytools.graph import compute_topological_order, CycleError, find_cycles

    # No cycle at all
    assert find_cycles({}) == []
    graph = {1: {2}, 2: set()}
    assert find_cycles(graph) == []

    # Non-Self Loop
    graph = {1: set(), 5: {1, 8}, 8: {5}}
    assert find_cycles(graph) == [[5, 8]]
    with pytest.raises(CycleError, match="5|8"):
        compute_topological_order(graph)

    # Self-Loop
    graph = {1: {1}, 5: {8}, 8: set()}
    assert find_cycles(graph) == [[1]]
    with pytest.raises(CycleError, match="1"):
        compute_topological_order(graph)

    # Invalid graph with loop
    graph = {1: {42}, 5: {8}, 8: {5}}
    # Can't run find_cycles on this graph since it is invalid
    with pytest.raises(CycleError, match="None"):
        compute_topological_order(graph)

    # Multiple loops
    graph = {1: {1}, 5: {8}, 8: {5}}
    assert find_cycles(graph) == [[1], [5, 8]]
    # Only the first cycle is reported when not asking for all of them
    assert find_cycles(graph, all_cycles=False) == [[1]]
    with pytest.raises(CycleError, match="1"):
        compute_topological_order(graph)

    # Cycle over multiple nodes
    graph = {4: {2}, 2: {3}, 3: {4}}
    assert find_cycles(graph) == [[4, 2, 3]]


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down