From 7276f0b1586c9bc346dcc07129b09202c8cc49d6 Mon Sep 17 00:00:00 2001 From: Stefaan Lippens Date: Mon, 16 Sep 2024 12:19:36 +0200 Subject: [PATCH] fixup! fixup! Issue #150 WIP more advanced pg splitting --- .../partitionedjobs/crossbackend.py | 5 +- tests/partitionedjobs/test_crossbackend.py | 51 ++++++++++++++++++- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/src/openeo_aggregator/partitionedjobs/crossbackend.py b/src/openeo_aggregator/partitionedjobs/crossbackend.py index cddba4f..f1229fc 100644 --- a/src/openeo_aggregator/partitionedjobs/crossbackend.py +++ b/src/openeo_aggregator/partitionedjobs/crossbackend.py @@ -555,7 +555,10 @@ def get_backend_candidates(self, node_id: NodeId) -> Union[frozenset[BackendId], # TODO: cache intermediate sets? (Only when caching is safe: e.g. wrapped graph is immutable/not manipulated) upstream_candidates = (self.get_backend_candidates(n) for n in self.node(node_id).depends_on) upstream_candidates = [c for c in upstream_candidates if c is not None] - return functools.reduce(lambda a, b: a.intersection(b), upstream_candidates) + if upstream_candidates: + return functools.reduce(lambda a, b: a.intersection(b), upstream_candidates) + else: + return None else: return None diff --git a/tests/partitionedjobs/test_crossbackend.py b/tests/partitionedjobs/test_crossbackend.py index 6c647b5..0d0c935 100644 --- a/tests/partitionedjobs/test_crossbackend.py +++ b/tests/partitionedjobs/test_crossbackend.py @@ -503,6 +503,16 @@ def test_get_backend_candidates_basic(self): assert graph.get_backend_candidates("c") == {"b2"} assert graph.get_backend_candidates("d") == set() + def test_get_backend_candidates_none(self): + graph = _FrozenGraph.from_edges( + [("a", "b"), ("b", "d"), ("c", "d")], + backend_candidates_map={}, + ) + assert graph.get_backend_candidates("a") is None + assert graph.get_backend_candidates("b") is None + assert graph.get_backend_candidates("c") is None + assert graph.get_backend_candidates("d") is None + def test_get_backend_candidates_intersection(self): graph = _FrozenGraph.from_edges( [("a", "d"), ("b", "d"), ("b", "e"), ("c", "e"), ("d", "f"), ("e", "f")], @@ -614,7 +624,7 @@ def test_produce_split_locations_simple(self): graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": "b1"}) assert list(graph.produce_split_locations()) == [[]] - def test_produce_split_locations_basic(self): + def test_produce_split_locations_merge_basic(self): """ Basic produce_split_locations use case: two load collections on different backends and a merge @@ -629,3 +639,42 @@ def test_produce_split_locations_basic(self): } graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]}) assert sorted(graph.produce_split_locations()) == [["lc1"], ["lc2"]] + + def test_produce_split_locations_merge_longer(self): + flat = { + "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, + "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "bands2": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc2"}, "bands": ["B02"]}}, + "merge1": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "bands2"}}, + }, + } + graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]}) + assert sorted(graph.produce_split_locations(limit=2)) == [["bands1"], ["bands2"]] + assert list(graph.produce_split_locations(limit=4)) == dirty_equals.IsOneOf( + [["bands1"], ["bands2"], ["lc1"], ["lc2"]], + [["bands2"], ["bands1"], ["lc2"], ["lc1"]], + ) + + def test_produce_split_locations_merge_longer_triangle(self): + flat = { + "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, + "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}}, + "mask1": { + "process_id": "mask", + "arguments": {"data": {"from_node": "bands1"}, "mask": {"from_node": "lc1"}}, + }, + "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "bands2": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc2"}, "bands": ["B02"]}}, + "merge1": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "mask1"}, "cube2": {"from_node": "bands2"}}, + }, + } + graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]}) + assert list(graph.produce_split_locations(limit=4)) == dirty_equals.IsOneOf( + [["mask1"], ["bands2"], ["lc1"], ["lc2"]], + [["bands2"], ["mask1"], ["lc2"], ["lc1"]], + )