diff --git a/src/graph_scheduler/condition.py b/src/graph_scheduler/condition.py index fdc442a2..29611dc1 100644 --- a/src/graph_scheduler/condition.py +++ b/src/graph_scheduler/condition.py @@ -2366,7 +2366,7 @@ def manipulate_graph(self, graph): f"Root logger level changed to {logging.getLevelName(new_level)}" ) - res = self._manipulate_graph(clone_graph(graph)) + result_graph = self._manipulate_graph(clone_graph(graph)) if self.debug: logger.info( @@ -2374,7 +2374,7 @@ def manipulate_graph(self, graph): ) root_logger.setLevel(orig_log_level) - return res + return result_graph @abc.abstractmethod def _manipulate_graph(self, graph): @@ -2548,6 +2548,33 @@ def _get_edge_conflicts(self, sender, receiver, new_sender_nodes, new_receiver_n return conflict_strings + def _preprocess_graph(self, graph): + return graph + + def _handle_reconnect_non_subject_receivers(self, graph, base_graph): + if self.reconnect_non_subject_receivers: + orig_owner_sender_nodes = base_graph[self.owner] + orig_receivers_old = get_receivers(base_graph) + + for sender in orig_owner_sender_nodes: + for receiver in orig_receivers_old[self.owner]: + # only add as receiver + if receiver not in self.nodes: + graph[receiver].add(sender) + logger.debug( + f'Reconnecting {receiver} as receiver of {sender}' + ) + return graph + + def _handle_remove_new_self_referential_edges(self, graph, base_graph): + if self.remove_new_self_referential_edges: + for node in graph: + if node not in base_graph[node]: + if node in graph[node]: + graph[node].remove(node) + logger.debug(f'New self-referential edge removed for {node}') + return graph + # TODO: document: # - owner_senders compares owner's original senders vs the union of senders of all subject nodes # - owner_receivers compares owner's original receivers vs the union of receivers of all subject nodes @@ -2556,7 +2583,10 @@ def _manipulate_graph( self, graph: typing.Dict[typing.Hashable, typing.Set] ) -> typing.Dict[typing.Hashable, typing.Set]: + unprocessed_graph = clone_graph(graph) + graph = self._preprocess_graph(graph) graph = super()._manipulate_graph(graph) + all_receivers_old = get_receivers(graph) old_owner_sender_nodes = graph[self.owner] old_subject_sender_nodes = {n: graph[n] for n in self.nodes} @@ -2636,22 +2666,10 @@ def _manipulate_graph( # subject_senders: typing.Union[Action, str, dict] = Action.KEEP, # subject_receivers: typing.Union[Action, str, dict] = Action.KEEP, # because in this case, E should maintain its dependency on B - if self.reconnect_non_subject_receivers: - for sender in old_owner_sender_nodes: - # ONLY add new edges if self.owner is no longer a receiver of sender - if sender not in result_graph[self.owner]: - for receiver in all_receivers_old[self.owner]: - # only add as receiver - if receiver not in self.nodes: - result_graph[receiver].add(sender) - - # TODO: move this to a step in manipulate_graph also? - if self.remove_new_self_referential_edges: - for node in result_graph: - if node not in graph[node]: - result_graph[node].discard(node) + result_graph = self._handle_reconnect_non_subject_receivers(result_graph, unprocessed_graph) + result_graph = self._handle_remove_new_self_referential_edges(result_graph, graph) - logger.debug(f'graph after {self} manipulate_graph:', result_graph) + logger.debug(f'graph after {self} _manipulate_graph:', result_graph) return result_graph def _apply_action_to_edge_sets( @@ -2782,17 +2800,13 @@ def _validate_graph(self, graph: typing.Dict[typing.Hashable, typing.Set]): # the changes to move reconnect_non_subject_receivers into outer # manipulate_graph method def _preprocess_graph(self, graph): - return graph - - @abc.abstractmethod - def _manipulate_graph(self, graph): for n in self.nodes: graph[n].discard(self.owner) graph[self.owner].discard(n) - return super()._manipulate_graph(graph) + return graph -class BeforeNodes(GSCWithActions, GSCReposition): +class BeforeNodes(GSCReposition, GSCWithActions): _already_valid_message = 'before' # TODO: choose preferred default for subject_receivers between KEEP and MERGE @@ -2831,7 +2845,7 @@ class BeforeNode(BeforeNodes, GSCSingleNode): pass -class WithNode(GSCWithActions, GSCSingleNode, GSCReposition): +class WithNode(GSCReposition, GSCWithActions, GSCSingleNode): _already_valid_message = 'with' # owner senders only valid as REPLACE or MERGE @@ -2873,7 +2887,7 @@ def _manipulate_graph(self, graph): return graph -class AfterNodes(GSCWithActions, GSCReposition): +class AfterNodes(GSCReposition, GSCWithActions): _already_valid_message = 'after' def __init__( diff --git a/tests/scheduling/test_condition.py b/tests/scheduling/test_condition.py index b19379e5..e2a1441d 100644 --- a/tests/scheduling/test_condition.py +++ b/tests/scheduling/test_condition.py @@ -2089,7 +2089,7 @@ def custom_gsc_func_1(graph): ] ) def test_CustomGraphStructureCondition(self, func, nodes, graph, expected_graph): - cond = gs.CustomGraphStructureCondition(func, *nodes) + cond = gs.CustomGraphStructureCondition(func, *nodes, debug=True) cond.owner = 'A' assert cond.manipulate_graph(graph) == expected_graph @@ -2144,7 +2144,7 @@ def _single_condition_test_helper( ), ( 'five_node_hub', 'C', ['D', 'B'], r'.*C is already before D.*(?