diff --git a/src/graph_scheduler/condition.py b/src/graph_scheduler/condition.py index 3139a93c..6ebd1eac 100644 --- a/src/graph_scheduler/condition.py +++ b/src/graph_scheduler/condition.py @@ -300,7 +300,7 @@ def converge(node, thresh): from graph_scheduler import _unit_registry from graph_scheduler.time import TimeScale -from graph_scheduler.utilities import call_with_pruned_args, clone_graph, get_descendants, get_receivers, gs_logging_formatter +from graph_scheduler.utilities import call_with_pruned_args, clone_graph, get_ancestors, get_descendants, get_receivers, gs_logging_formatter _additional__all__ = ['Action', 'ConditionError', 'ConditionSet'] @@ -2417,6 +2417,7 @@ def __init__( debug: bool = False, ignore_conflicts: bool = False, remove_descendant_senders: bool = True, + remove_ancestor_receivers: bool = True, **kwargs, ): subject_senders = self._handle_subject_arg( @@ -2437,6 +2438,7 @@ def __init__( debug=debug, ignore_conflicts=ignore_conflicts, remove_descendant_senders=remove_descendant_senders, + remove_ancestor_receivers=remove_ancestor_receivers, **kwargs ) @@ -2612,13 +2614,15 @@ def _manipulate_graph( graph, # [self.owner], self.nodes, - self.remove_descendant_senders, + remove_descendants=self.remove_descendant_senders, ) logger.debug(f'Apply owner_receivers ({self.owner_receivers}) to {self.owner}') new_receiver_nodes[self.owner] = self._process_action( all_receivers_old[self.owner], set.union(*[all_receivers_old[n] for n in self.nodes]), self.owner_receivers, + graph=graph, + remove_ancestors=self.remove_ancestor_receivers, ) for n in self.nodes: @@ -2633,13 +2637,15 @@ def _manipulate_graph( graph, # self.nodes, [self.owner], - self.remove_descendant_senders, + remove_descendants=self.remove_descendant_senders, ) logger.debug(f'Apply subject_receivers ({subject_receivers_n}) to {n}') new_receiver_nodes[n] = self._process_action( all_receivers_old[n], all_receivers_old[self.owner], subject_receivers_n, + graph=graph, + remove_ancestors=self.remove_ancestor_receivers, ) for receiver in new_sender_nodes: @@ -2719,6 +2725,8 @@ def _apply_action_to_edge_sets( ) return res + # TODO: do desc/ances into separate methods/just one + def _process_action( self, source_neighbors: typing.Set, @@ -2727,13 +2735,16 @@ def _process_action( graph=None, comparison_nodes=None, remove_descendants: bool = False, + remove_ancestors: bool = False, ) -> typing.Set: result = self._apply_action_to_edge_sets( source_neighbors, comparison_neighbors, action, ) + if remove_descendants: + assert graph is not None descendants = get_descendants(graph) logger.debug(descendants) @@ -2741,7 +2752,8 @@ def _process_action( # ipdb.set_trace() descendants_to_ignore = set() for k, v in descendants.items(): - if k is self.owner or k in self.nodes: + # if k is self.owner or k in self.nodes: + if k in self.nodes: descendants_to_ignore = descendants_to_ignore.union(v) modified_result = result - descendants_to_ignore @@ -2751,6 +2763,26 @@ def _process_action( ) result = modified_result + if remove_ancestors: + assert graph is not None + ancestors = get_ancestors(graph) + logger.debug(ancestors) + + # import ipdb + # ipdb.set_trace() + ancestors_to_ignore = set() + for k, v in ancestors.items(): + # if k is self.owner or k in self.nodes: + if k in self.nodes: + ancestors_to_ignore = ancestors_to_ignore.union(v) + + modified_result = result - ancestors_to_ignore + if modified_result != result: + logger.debug( + f'Removing ancestors {ancestors_to_ignore} from {result} giving {modified_result}' + ) + result = modified_result + return result @staticmethod @@ -2946,7 +2978,7 @@ def __init__( *nodes, owner_senders: typing.Union[Action, str] = Action.REPLACE, owner_receivers: typing.Union[Action, str] = Action.MERGE, - subject_senders: typing.Union[Action, str, dict] = Action.REPLACE, + subject_senders: typing.Union[Action, str, dict] = Action.MERGE, subject_receivers: typing.Union[Action, str, dict] = Action.KEEP, reconnect_non_subject_receivers: bool = True, ): diff --git a/src/graph_scheduler/utilities.py b/src/graph_scheduler/utilities.py index d6134630..dbaceeb3 100644 --- a/src/graph_scheduler/utilities.py +++ b/src/graph_scheduler/utilities.py @@ -156,3 +156,19 @@ def cached_descendants(g): return {node: nx.descendants(nx_graph, node) for node in graph} return cached_descendants(frozen_graph(graph)) + + +def get_ancestors(graph: typing.Dict) -> typing.Dict: + """ + Returns a dict containing the ancestors of each node in dependency + dictionary **graph** + + Args: + graph (dict): a dependency dictionary + """ + @functools.lru_cache() + def cached_ancestors(g): + nx_graph = nx.DiGraph(graph).reverse() + return {node: nx.ancestors(nx_graph, node) for node in graph} + + return cached_ancestors(frozen_graph(graph)) diff --git a/tests/scheduling/test_condition.py b/tests/scheduling/test_condition.py index 9dc5f2f1..13ad4e1b 100644 --- a/tests/scheduling/test_condition.py +++ b/tests/scheduling/test_condition.py @@ -2134,10 +2134,9 @@ def _single_condition_test_helper( 'five_node_hub', 'D', ['C'], None, {'A': set(), 'B': set(), 'C': {'A', 'B', 'D'}, 'D': {'A', 'B'}, 'E': {'C'}}, ), - # creates cycle due to subject_senders default MERGE ( 'five_node_hub', 'D', ['A'], None, - {'A': {'D'}, 'B': set(), 'C': {'A', 'B'}, 'D': {'C'}, 'E': {'C'}} + {'A': {'D'}, 'B': set(), 'C': {'A', 'B'}, 'D': set(), 'E': {'C'}} ), ( 'five_node_hub', 'C', ['D'], r'.*C is already before D.*Condition is ignored.', @@ -2160,7 +2159,7 @@ def _single_condition_test_helper( 'B': set(), 'C': {'A', 'B', 'E'}, 'D': {'A', 'C'}, - 'E': {'A', 'B', 'H'}, + 'E': {'A', 'B'}, 'F': {'C', 'D', 'E'}, 'G': {'E'}, 'H': {'G'},