Skip to content

Commit

Permalink
Add graph utilities for tag propagation (#220)
Browse files Browse the repository at this point in the history
  • Loading branch information
a-alveyblanc authored Sep 4, 2024
1 parent cee5027 commit 5ab81ef
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 9 deletions.
75 changes: 66 additions & 9 deletions pytools/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Copyright (C) 2009-2013 Andreas Kloeckner
Copyright (C) 2020 Matt Wala
Copyright (C) 2020 James Stevens
Copyright (C) 2024 Addison Alvey-Blanco
"""

__license__ = """
Expand Down Expand Up @@ -47,6 +48,8 @@
.. autofunction:: as_graphviz_dot
.. autofunction:: validate_graph
.. autofunction:: is_connected
.. autofunction:: undirected_graph_from_edges
.. autofunction:: get_reachable_nodes
Type Variables Used
-------------------
Expand All @@ -71,13 +74,16 @@
Callable,
Collection,
Dict,
FrozenSet,
Generic,
Hashable,
Iterable,
Iterator,
List,
Mapping,
MutableSet,
Optional,
Protocol,
Set,
Tuple,
TypeVar,
Expand All @@ -98,7 +104,6 @@

NodeT = TypeVar("NodeT", bound=Hashable)


GraphT: TypeAlias[NodeT] = Mapping[NodeT, Collection[NodeT]]


Expand Down Expand Up @@ -263,8 +268,13 @@ def __init__(self, node: NodeT) -> None:
self.node = node


class _SupportsLT(Protocol):
def __lt__(self, other: object) -> bool:
...


@dataclass(frozen=True)
class HeapEntry(Generic[NodeT]):
class _HeapEntry(Generic[NodeT]):
"""
Helper class to compare associated keys while comparing the elements in
heap operations.
Expand All @@ -273,9 +283,9 @@ class HeapEntry(Generic[NodeT]):
<https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Lib/heapq.py#L135-L138>.
"""
node: NodeT
key: Any
key: _SupportsLT

def __lt__(self, other: HeapEntry) -> bool:
def __lt__(self, other: _HeapEntry[NodeT]) -> bool:
return self.key < other.key


Expand Down Expand Up @@ -321,7 +331,7 @@ def compute_topological_order(graph: GraphT[NodeT],
# heap: list of instances of HeapEntry(n) where 'n' is a node in
# 'graph' with no predecessor. Nodes with no predecessors are the
# schedulable candidates.
heap = [HeapEntry(n, keyfunc(n))
heap = [_HeapEntry(n, keyfunc(n))
for n, num_preds in nodes_to_num_predecessors.items()
if num_preds == 0]
heapify(heap)
Expand All @@ -336,7 +346,7 @@ def compute_topological_order(graph: GraphT[NodeT],
for child in graph.get(node_to_be_scheduled, ()):
nodes_to_num_predecessors[child] -= 1
if nodes_to_num_predecessors[child] == 0:
heappush(heap, HeapEntry(child, keyfunc(child)))
heappush(heap, _HeapEntry(child, keyfunc(child)))

if len(order) != total_num_nodes:
# any node which has a predecessor left is a part of a cycle
Expand Down Expand Up @@ -457,11 +467,11 @@ def as_graphviz_dot(graph: GraphT[NodeT],
from pytools.graphviz import dot_escape

if node_labels is None:
def node_labels(x):
def node_labels(x: NodeT) -> str:
return str(x)

if edge_labels is None:
def edge_labels(x, y):
def edge_labels(x: NodeT, y: NodeT) -> str:
return ""

node_to_id = {}
Expand Down Expand Up @@ -511,7 +521,7 @@ def validate_graph(graph: GraphT[NodeT]) -> None:
# }}}


# {{{
# {{{ is_connected

def is_connected(graph: GraphT[NodeT]) -> bool:
"""
Expand Down Expand Up @@ -542,5 +552,52 @@ def dfs(node: NodeT) -> None:

return visited == graph.keys()

# }}}


def undirected_graph_from_edges(
edges: Iterable[Tuple[NodeT, NodeT]],
) -> GraphT[NodeT]:
"""
Constructs an undirected graph using *edges*.
:arg edges: An :class:`Iterable` of pairs of related :class:`NodeT` s.
:returns: A :class:`GraphT` that is the undirected graph.
"""
undirected_graph: Dict[NodeT, Set[NodeT]] = {}

for lhs, rhs in edges:
if lhs == rhs:
raise TypeError("Found loop in edges,"
f" LHS, RHS = {lhs}")

undirected_graph.setdefault(lhs, set()).add(rhs)
undirected_graph.setdefault(rhs, set()).add(lhs)

return undirected_graph


def get_reachable_nodes(
undirected_graph: GraphT[NodeT],
source_node: NodeT) -> FrozenSet[NodeT]:
"""
Returns a :class:`frozenset` of all nodes in *undirected_graph* that are
reachable from *source_node*.
"""
nodes_visited: Set[NodeT] = set()
nodes_to_visit = {source_node}

while nodes_to_visit:
current_node = nodes_to_visit.pop()
nodes_visited.add(current_node)

neighbors = undirected_graph[current_node]
nodes_to_visit.update({node
for node in neighbors
if node not in nodes_visited})

return frozenset(nodes_visited)


# vim: foldmethod=marker
32 changes: 32 additions & 0 deletions pytools/test/test_graph_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,38 @@ def test_is_connected():
assert is_connected({})


def test_propagation_graph_tools():
from pytools.graph import (
get_reachable_nodes,
undirected_graph_from_edges,
)

vars = {"a", "b", "c", "d", "e", "f", "g"}

constraints = [
("a", "b"),
("a", "d"),
("c", "d"),
("e", "f"),
("f", "g")
]

all_reachable_nodes = {
"a": frozenset({"a", "b", "c", "d"}),
"b": frozenset({"a", "b", "c", "d"}),
"c": frozenset({"a", "b", "c", "d"}),
"e": frozenset({"e", "f", "g"}),
"f": frozenset({"e", "f", "g"}),
"g": frozenset({"e", "f", "g"})
}

propagation_graph = undirected_graph_from_edges(constraints)
assert (
all_reachable_nodes[var] == get_reachable_nodes(propagation_graph, var)
for var in vars
)


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down
1 change: 1 addition & 0 deletions run-mypy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ mypy --show-error-codes pytools

mypy --strict --follow-imports=silent \
pytools/tag.py \
pytools/graph.py \
pytools/datatable.py \
pytools/persistent_dict.py

0 comments on commit 5ab81ef

Please sign in to comment.