fixup! fixup! fixup! Issue #150 WIP more advanced pg splitting
soxofaan committed Sep 16, 2024
1 parent 7276f0b commit 46b5dae
Showing 2 changed files with 82 additions and 12 deletions.
22 changes: 10 additions & 12 deletions src/openeo_aggregator/partitionedjobs/crossbackend.py
@@ -398,6 +398,7 @@ class _FrozenNode:
"""

# TODO: instead of frozen dataclass: have __init__ with some type casting/validation. Or use attrs?
# TODO: better name for this class?

# Node ids of other nodes this node depends on (aka parents)
depends_on: frozenset[NodeId]
@@ -406,6 +407,7 @@ class _FrozenNode:

# Backend ids this node is marked to be supported on
# value None means it is unknown/unconstrained for this node
# TODO: Move this to _FrozenGraph as responsibility?
backend_candidates: Union[frozenset[BackendId], None]

def __repr__(self):
@@ -418,15 +420,6 @@ def __repr__(self):
+ [f">{f}" for f in self.flows_to]
)

def clone(self, backend_candidates: Union[frozenset[BackendId], None] = _UNSET) -> "_FrozenNode":
"""Clone this node, optionally updating backend_candidates"""
backend_candidates = self.backend_candidates if backend_candidates is _UNSET else backend_candidates
return _FrozenNode(
depends_on=self.depends_on,
flows_to=self.flows_to,
backend_candidates=backend_candidates,
)
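
As an aside (not from this commit): given the field comments above, constructing such a frozen node directly looks roughly like the sketch below, mirroring the style of the tests further down. The node ids "load" and "save" and the backend id "b1" are purely hypothetical.

node = _FrozenNode(
    depends_on=frozenset(["load"]),  # parent node ids this node depends on
    flows_to=frozenset(["save"]),  # child node ids that consume this node's output
    backend_candidates=frozenset(["b1"]),  # None would mean unknown/unconstrained
)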


class _FrozenGraph:
"""
@@ -609,9 +602,10 @@ def get_flow_weights(node_id: NodeId) -> Dict[NodeId, fractions.Fraction]:
# Select articulation points: nodes where all flows have weight 1
return set(node_id for node_id, flows in flow_weights.items() if all(w == 1 for w in flows.values()))

def _split_at(self, split_node_id: NodeId) -> Tuple[_FrozenGraph, _FrozenGraph]:
def split_at(self, split_node_id: NodeId) -> Tuple[_FrozenGraph, _FrozenGraph]:
"""
Split graph at given node id
Split graph at the given node id (must be an articulation point),
creating two new graphs that contain the original nodes and an adapted copy of the split node.
"""
split_node = self.node(split_node_id)

@@ -627,6 +621,10 @@ def next_nodes(node_id: NodeId) -> Iterable[NodeId]:
return node.depends_on.union(node.flows_to)

up_node_ids = set(self._walk(seeds=[split_node_id], next_nodes=next_nodes))

if split_node.flows_to.intersection(up_node_ids):
raise GraphSplitException(f"Graph cannot be split at {split_node_id}: not an articulation point.")

up_graph = {n: self.node(n) for n in up_node_ids}
up_graph[split_node_id] = _FrozenNode(
depends_on=split_node.depends_on,
@@ -673,7 +671,7 @@ def produce_split_locations(self, limit: int = 2) -> Iterator[List[NodeId]]:
# TODO: smarter picking of split node (e.g. one with most upstream nodes)
for split_node_id in split_options[:limit]:
# Split graph at this articulation point
down, up = self._split_at(split_node_id)
down, up = self.split_at(split_node_id)
if down.find_forsaken_nodes():
down_splits = list(down.produce_split_locations(limit=limit - 1))
else:
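
As a usage sketch (not from this commit), based on the from_edges, split_at and iter_nodes helpers exercised in the tests below: splitting a simple chain at an articulation point yields an "up" graph that ends at the split node and a "down" graph that starts from it; produce_split_locations builds on this by recursively splitting further wherever the "down" part still has forsaken nodes.

graph = _FrozenGraph.from_edges([("a", "b"), ("b", "c")], backend_candidates_map={"a": "A"})
down, up = graph.split_at("b")
assert sorted(nid for nid, _ in up.iter_nodes()) == ["a", "b"]
assert sorted(nid for nid, _ in down.iter_nodes()) == ["b", "c"]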
72 changes: 72 additions & 0 deletions tests/partitionedjobs/test_crossbackend.py
@@ -15,6 +15,7 @@
from openeo_aggregator.partitionedjobs import PartitionedJob, SubJob
from openeo_aggregator.partitionedjobs.crossbackend import (
CrossBackendSplitter,
GraphSplitException,
SubGraphId,
_FrozenGraph,
_FrozenNode,
@@ -615,6 +616,77 @@ def test_find_articulation_points(self, flat, expected):
graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={})
assert graph.find_articulation_points() == expected

def test_split_at_minimal(self):
graph = _FrozenGraph.from_edges([("a", "b")], backend_candidates_map={"a": "A"})
# Split at a
down, up = graph.split_at("a")
assert sorted(up.iter_nodes()) == [
("a", _FrozenNode(frozenset(), frozenset(), backend_candidates=frozenset(["A"]))),
]
assert sorted(down.iter_nodes()) == [
("a", _FrozenNode(frozenset(), frozenset(["b"]), backend_candidates=None)),
("b", _FrozenNode(frozenset(["a"]), frozenset([]), backend_candidates=None)),
]
# Split at b
down, up = graph.split_at("b")
assert sorted(up.iter_nodes()) == [
("a", _FrozenNode(frozenset(), frozenset(["b"]), backend_candidates=frozenset(["A"]))),
("b", _FrozenNode(frozenset(["a"]), frozenset([]), backend_candidates=None)),
]
assert sorted(down.iter_nodes()) == [
("b", _FrozenNode(frozenset(), frozenset(), backend_candidates=None)),
]

def test_split_at_basic(self):
graph = _FrozenGraph.from_edges([("a", "b"), ("b", "c")], backend_candidates_map={"a": "A"})
down, up = graph.split_at("b")
assert sorted(up.iter_nodes()) == [
("a", _FrozenNode(frozenset(), frozenset(["b"]), backend_candidates=frozenset(["A"]))),
("b", _FrozenNode(frozenset(["a"]), frozenset([]), backend_candidates=None)),
]
assert sorted(down.iter_nodes()) == [
("b", _FrozenNode(frozenset(), frozenset(["c"]), backend_candidates=None)),
("c", _FrozenNode(frozenset(["b"]), frozenset([]), backend_candidates=None)),
]

def test_split_at_complex(self):
graph = _FrozenGraph.from_edges(
[("a", "b"), ("a", "c"), ("b", "d"), ("c", "d"), ("c", "e"), ("e", "g"), ("f", "g"), ("X", "Y")]
)
down, up = graph.split_at("e")
assert sorted(up.iter_nodes()) == sorted(
_FrozenGraph.from_edges([("a", "b"), ("a", "c"), ("b", "d"), ("c", "d"), ("c", "e")]).iter_nodes()
)
assert sorted(down.iter_nodes()) == sorted(
_FrozenGraph.from_edges([("e", "g"), ("f", "g"), ("X", "Y")]).iter_nodes()
)

def test_split_at_non_articulation_point(self):
graph = _FrozenGraph.from_edges([("a", "b"), ("b", "c"), ("a", "c")])
with pytest.raises(GraphSplitException, match="not an articulation point"):
_ = graph.split_at("b")

# These should still work
down, up = graph.split_at("a")
assert sorted(up.iter_nodes()) == [
("a", _FrozenNode(frozenset(), frozenset(), backend_candidates=None)),
]
assert sorted(down.iter_nodes()) == [
("a", _FrozenNode(frozenset(), frozenset(["b", "c"]), backend_candidates=None)),
("b", _FrozenNode(frozenset(["a"]), frozenset(["c"]), backend_candidates=None)),
("c", _FrozenNode(frozenset(["a", "b"]), frozenset(), backend_candidates=None)),
]

down, up = graph.split_at("c")
assert sorted(up.iter_nodes()) == [
("a", _FrozenNode(frozenset(), frozenset(["b", "c"]), backend_candidates=None)),
("b", _FrozenNode(frozenset(["a"]), frozenset(["c"]), backend_candidates=None)),
("c", _FrozenNode(frozenset(["a", "b"]), frozenset(), backend_candidates=None)),
]
assert sorted(down.iter_nodes()) == [
("c", _FrozenNode(frozenset(), frozenset(), backend_candidates=None)),
]

def test_produce_split_locations_simple(self):
"""Simple produce_split_locations use case: no need for splits"""
flat = {
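
A small illustrative helper (hypothetical, not part of the codebase) shows how a caller could use the GraphSplitException raised above to keep only valid split points:

def valid_split_points(graph: _FrozenGraph, candidates):
    """Yield only the candidate node ids at which the graph can actually be split."""
    for node_id in candidates:
        try:
            graph.split_at(node_id)
        except GraphSplitException:
            continue
        yield node_id

On the triangle graph from test_split_at_non_articulation_point, list(valid_split_points(graph, ["a", "b", "c"])) would keep "a" and "c" but drop "b".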
