From 60405a9737020f402c7e165c17240318d7b863f4 Mon Sep 17 00:00:00 2001 From: Katherine Mantel Date: Mon, 16 Oct 2023 23:55:25 +0000 Subject: [PATCH] WIP: gsc add remove descendant --- src/graph_scheduler/condition.py | 54 +++++++++++++++++++++++++++++--- src/graph_scheduler/utilities.py | 20 +++++++++++- 2 files changed, 68 insertions(+), 6 deletions(-) diff --git a/src/graph_scheduler/condition.py b/src/graph_scheduler/condition.py index a1b42e5a..3139a93c 100644 --- a/src/graph_scheduler/condition.py +++ b/src/graph_scheduler/condition.py @@ -300,7 +300,7 @@ def converge(node, thresh): from graph_scheduler import _unit_registry from graph_scheduler.time import TimeScale -from graph_scheduler.utilities import call_with_pruned_args, clone_graph, get_receivers, gs_logging_formatter +from graph_scheduler.utilities import call_with_pruned_args, clone_graph, get_descendants, get_receivers, gs_logging_formatter _additional__all__ = ['Action', 'ConditionError', 'ConditionSet'] @@ -2416,6 +2416,7 @@ def __init__( remove_new_self_referential_edges: bool = True, debug: bool = False, ignore_conflicts: bool = False, + remove_descendant_senders: bool = True, **kwargs, ): subject_senders = self._handle_subject_arg( @@ -2435,6 +2436,7 @@ def __init__( remove_new_self_referential_edges=remove_new_self_referential_edges, debug=debug, ignore_conflicts=ignore_conflicts, + remove_descendant_senders=remove_descendant_senders, **kwargs ) @@ -2603,13 +2605,17 @@ def _manipulate_graph( result_graph = clone_graph(graph) logger.debug(f'Apply owner_senders ({self.owner_senders}) to {self.owner}') - new_sender_nodes[self.owner] = self._apply_action_to_edge_sets( + new_sender_nodes[self.owner] = self._process_action( old_owner_sender_nodes, set.union(*old_subject_sender_nodes.values()), self.owner_senders, + graph, + # [self.owner], + self.nodes, + self.remove_descendant_senders, ) logger.debug(f'Apply owner_receivers ({self.owner_receivers}) to {self.owner}') - new_receiver_nodes[self.owner] = self._apply_action_to_edge_sets( + new_receiver_nodes[self.owner] = self._process_action( all_receivers_old[self.owner], set.union(*[all_receivers_old[n] for n in self.nodes]), self.owner_receivers, @@ -2620,13 +2626,17 @@ def _manipulate_graph( subject_receivers_n = self._get_subject_action(self.subject_receivers, n) logger.debug(f'Apply subject_senders ({subject_senders_n}) to {n}') - new_sender_nodes[n] = self._apply_action_to_edge_sets( + new_sender_nodes[n] = self._process_action( old_subject_sender_nodes[n], old_owner_sender_nodes, subject_senders_n, + graph, + # self.nodes, + [self.owner], + self.remove_descendant_senders, ) logger.debug(f'Apply subject_receivers ({subject_receivers_n}) to {n}') - new_receiver_nodes[n] = self._apply_action_to_edge_sets( + new_receiver_nodes[n] = self._process_action( all_receivers_old[n], all_receivers_old[self.owner], subject_receivers_n, @@ -2709,6 +2719,40 @@ def _apply_action_to_edge_sets( ) return res + def _process_action( + self, + source_neighbors: typing.Set, + comparison_neighbors: typing.Set, + action: Action, + graph=None, + comparison_nodes=None, + remove_descendants: bool = False, + ) -> typing.Set: + result = self._apply_action_to_edge_sets( + source_neighbors, + comparison_neighbors, + action, + ) + if remove_descendants: + descendants = get_descendants(graph) + logger.debug(descendants) + + # import ipdb + # ipdb.set_trace() + descendants_to_ignore = set() + for k, v in descendants.items(): + if k is self.owner or k in self.nodes: + descendants_to_ignore = descendants_to_ignore.union(v) + + modified_result = result - descendants_to_ignore + if modified_result != result: + logger.debug( + f'Removing descendants {descendants_to_ignore} from {result} giving {modified_result}' + ) + result = modified_result + + return result + @staticmethod def _parse_action_arg(action): try: diff --git a/src/graph_scheduler/utilities.py b/src/graph_scheduler/utilities.py index 67d2a4d7..d6134630 100644 --- a/src/graph_scheduler/utilities.py +++ b/src/graph_scheduler/utilities.py @@ -1,10 +1,12 @@ import collections +import functools import inspect import logging -import networkx as nx import typing import weakref +import networkx as nx + __all__ = ['clone_graph', 'get_receivers', 'output_graph_png'] _unused_args_sig_cache = weakref.WeakKeyDictionary() @@ -138,3 +140,19 @@ def output_graph_png( pd = nx.drawing.nx_pydot.to_pydot(gr) pd.write_png(filename) print(f'Wrote png to {filename}') + + +def get_descendants(graph: typing.Dict) -> typing.Dict: + """ + Returns a dict containing the descendants of each node in dependency + dictionary **graph** + + Args: + graph (dict): a dependency dictionary + """ + @functools.lru_cache() + def cached_descendants(g): + nx_graph = nx.DiGraph(graph).reverse() + return {node: nx.descendants(nx_graph, node) for node in graph} + + return cached_descendants(frozen_graph(graph))