diff --git a/src/openeo_aggregator/partitionedjobs/crossbackend.py b/src/openeo_aggregator/partitionedjobs/crossbackend.py index f1229fc..40ce637 100644 --- a/src/openeo_aggregator/partitionedjobs/crossbackend.py +++ b/src/openeo_aggregator/partitionedjobs/crossbackend.py @@ -398,6 +398,7 @@ class _FrozenNode: """ # TODO: instead of frozen dataclass: have __init__ with some type casting/validation. Or use attrs? + # TODO: better name for this class? # Node ids of other nodes this node depends on (aka parents) depends_on: frozenset[NodeId] @@ -406,6 +407,7 @@ class _FrozenNode: # Backend ids this node is marked to be supported on # value None means it is unknown/unconstrained for this node + # TODO: Move this to _FrozenGraph as responsibility? backend_candidates: Union[frozenset[BackendId], None] def __repr__(self): @@ -418,15 +420,6 @@ def __repr__(self): + [f">{f}" for f in self.flows_to] ) - def clone(self, backend_candidates: Union[frozenset[BackendId], None] = _UNSET) -> "_FrozenNode": - """Clone this node, optionally updating backend_candidates""" - backend_candidates = self.backend_candidates if backend_candidates is _UNSET else backend_candidates - return _FrozenNode( - depends_on=self.depends_on, - flows_to=self.flows_to, - backend_candidates=backend_candidates, - ) - class _FrozenGraph: """ @@ -609,9 +602,10 @@ def get_flow_weights(node_id: NodeId) -> Dict[NodeId, fractions.Fraction]: # Select articulation points: nodes where all flows have weight 1 return set(node_id for node_id, flows in flow_weights.items() if all(w == 1 for w in flows.values())) - def _split_at(self, split_node_id: NodeId) -> Tuple[_FrozenGraph, _FrozenGraph]: + def split_at(self, split_node_id: NodeId) -> Tuple[_FrozenGraph, _FrozenGraph]: """ - Split graph at given node id + Split graph at given node id (must be articulation point), + creating two new graphs, containing original nodes and adaptation of the split node. """ split_node = self.node(split_node_id) @@ -627,6 +621,10 @@ def next_nodes(node_id: NodeId) -> Iterable[NodeId]: return node.depends_on.union(node.flows_to) up_node_ids = set(self._walk(seeds=[split_node_id], next_nodes=next_nodes)) + + if split_node.flows_to.intersection(up_node_ids): + raise GraphSplitException(f"Graph can not be split at {split_node_id}: not an articulation point.") + up_graph = {n: self.node(n) for n in up_node_ids} up_graph[split_node_id] = _FrozenNode( depends_on=split_node.depends_on, @@ -673,7 +671,7 @@ def produce_split_locations(self, limit: int = 2) -> Iterator[List[NodeId]]: # TODO: smarter picking of split node (e.g. one with most upstream nodes) for split_node_id in split_options[:limit]: # Split graph at this articulation point - down, up = self._split_at(split_node_id) + down, up = self.split_at(split_node_id) if down.find_forsaken_nodes(): down_splits = list(down.produce_split_locations(limit=limit - 1)) else: diff --git a/tests/partitionedjobs/test_crossbackend.py b/tests/partitionedjobs/test_crossbackend.py index 0d0c935..f2c8c85 100644 --- a/tests/partitionedjobs/test_crossbackend.py +++ b/tests/partitionedjobs/test_crossbackend.py @@ -15,6 +15,7 @@ from openeo_aggregator.partitionedjobs import PartitionedJob, SubJob from openeo_aggregator.partitionedjobs.crossbackend import ( CrossBackendSplitter, + GraphSplitException, SubGraphId, _FrozenGraph, _FrozenNode, @@ -615,6 +616,77 @@ def test_find_articulation_points(self, flat, expected): graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={}) assert graph.find_articulation_points() == expected + def test_split_at_minimal(self): + graph = _FrozenGraph.from_edges([("a", "b")], backend_candidates_map={"a": "A"}) + # Split at a + down, up = graph.split_at("a") + assert sorted(up.iter_nodes()) == [ + ("a", _FrozenNode(frozenset(), frozenset(), backend_candidates=frozenset(["A"]))), + ] + assert sorted(down.iter_nodes()) == [ + ("a", _FrozenNode(frozenset(), frozenset(["b"]), backend_candidates=None)), + ("b", _FrozenNode(frozenset(["a"]), frozenset([]), backend_candidates=None)), + ] + # Split at b + down, up = graph.split_at("b") + assert sorted(up.iter_nodes()) == [ + ("a", _FrozenNode(frozenset(), frozenset(["b"]), backend_candidates=frozenset(["A"]))), + ("b", _FrozenNode(frozenset(["a"]), frozenset([]), backend_candidates=None)), + ] + assert sorted(down.iter_nodes()) == [ + ("b", _FrozenNode(frozenset(), frozenset(), backend_candidates=None)), + ] + + def test_split_at_basic(self): + graph = _FrozenGraph.from_edges([("a", "b"), ("b", "c")], backend_candidates_map={"a": "A"}) + down, up = graph.split_at("b") + assert sorted(up.iter_nodes()) == [ + ("a", _FrozenNode(frozenset(), frozenset(["b"]), backend_candidates=frozenset(["A"]))), + ("b", _FrozenNode(frozenset(["a"]), frozenset([]), backend_candidates=None)), + ] + assert sorted(down.iter_nodes()) == [ + ("b", _FrozenNode(frozenset(), frozenset(["c"]), backend_candidates=None)), + ("c", _FrozenNode(frozenset(["b"]), frozenset([]), backend_candidates=None)), + ] + + def test_split_at_complex(self): + graph = _FrozenGraph.from_edges( + [("a", "b"), ("a", "c"), ("b", "d"), ("c", "d"), ("c", "e"), ("e", "g"), ("f", "g"), ("X", "Y")] + ) + down, up = graph.split_at("e") + assert sorted(up.iter_nodes()) == sorted( + _FrozenGraph.from_edges([("a", "b"), ("a", "c"), ("b", "d"), ("c", "d"), ("c", "e")]).iter_nodes() + ) + assert sorted(down.iter_nodes()) == sorted( + _FrozenGraph.from_edges([("e", "g"), ("f", "g"), ("X", "Y")]).iter_nodes() + ) + + def test_split_at_non_articulation_point(self): + graph = _FrozenGraph.from_edges([("a", "b"), ("b", "c"), ("a", "c")]) + with pytest.raises(GraphSplitException, match="not an articulation point"): + _ = graph.split_at("b") + + # These should still work + down, up = graph.split_at("a") + assert sorted(up.iter_nodes()) == [ + ("a", _FrozenNode(frozenset(), frozenset(), backend_candidates=None)), + ] + assert sorted(down.iter_nodes()) == [ + ("a", _FrozenNode(frozenset(), frozenset(["b", "c"]), backend_candidates=None)), + ("b", _FrozenNode(frozenset(["a"]), frozenset(["c"]), backend_candidates=None)), + ("c", _FrozenNode(frozenset(["a", "b"]), frozenset(), backend_candidates=None)), + ] + + down, up = graph.split_at("c") + assert sorted(up.iter_nodes()) == [ + ("a", _FrozenNode(frozenset(), frozenset(["b", "c"]), backend_candidates=None)), + ("b", _FrozenNode(frozenset(["a"]), frozenset(["c"]), backend_candidates=None)), + ("c", _FrozenNode(frozenset(["a", "b"]), frozenset(), backend_candidates=None)), + ] + assert sorted(down.iter_nodes()) == [ + ("c", _FrozenNode(frozenset(), frozenset(), backend_candidates=None)), + ] + def test_produce_split_locations_simple(self): """Simple produce_split_locations use case: no need for splits""" flat = {