Skip to content

Commit

Permalink
mg gsc add remove ancestor recs
Browse files Browse the repository at this point in the history
  • Loading branch information
kmantel committed Oct 17, 2023
1 parent 60405a9 commit 0580d3e
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 8 deletions.
42 changes: 37 additions & 5 deletions src/graph_scheduler/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def converge(node, thresh):

from graph_scheduler import _unit_registry
from graph_scheduler.time import TimeScale
from graph_scheduler.utilities import call_with_pruned_args, clone_graph, get_descendants, get_receivers, gs_logging_formatter
from graph_scheduler.utilities import call_with_pruned_args, clone_graph, get_ancestors, get_descendants, get_receivers, gs_logging_formatter


_additional__all__ = ['Action', 'ConditionError', 'ConditionSet']
Expand Down Expand Up @@ -2417,6 +2417,7 @@ def __init__(
debug: bool = False,
ignore_conflicts: bool = False,
remove_descendant_senders: bool = True,
remove_ancestor_receivers: bool = True,
**kwargs,
):
subject_senders = self._handle_subject_arg(
Expand All @@ -2437,6 +2438,7 @@ def __init__(
debug=debug,
ignore_conflicts=ignore_conflicts,
remove_descendant_senders=remove_descendant_senders,
remove_ancestor_receivers=remove_ancestor_receivers,
**kwargs
)

Expand Down Expand Up @@ -2612,13 +2614,15 @@ def _manipulate_graph(
graph,
# [self.owner],
self.nodes,
self.remove_descendant_senders,
remove_descendants=self.remove_descendant_senders,
)
logger.debug(f'Apply owner_receivers ({self.owner_receivers}) to {self.owner}')
new_receiver_nodes[self.owner] = self._process_action(
all_receivers_old[self.owner],
set.union(*[all_receivers_old[n] for n in self.nodes]),
self.owner_receivers,
graph=graph,
remove_ancestors=self.remove_ancestor_receivers,
)

for n in self.nodes:
Expand All @@ -2633,13 +2637,15 @@ def _manipulate_graph(
graph,
# self.nodes,
[self.owner],
self.remove_descendant_senders,
remove_descendants=self.remove_descendant_senders,
)
logger.debug(f'Apply subject_receivers ({subject_receivers_n}) to {n}')
new_receiver_nodes[n] = self._process_action(
all_receivers_old[n],
all_receivers_old[self.owner],
subject_receivers_n,
graph=graph,
remove_ancestors=self.remove_ancestor_receivers,
)

for receiver in new_sender_nodes:
Expand Down Expand Up @@ -2719,6 +2725,8 @@ def _apply_action_to_edge_sets(
)
return res

# TODO: do desc/ances into separate methods/just one

def _process_action(
self,
source_neighbors: typing.Set,
Expand All @@ -2727,21 +2735,25 @@ def _process_action(
graph=None,
comparison_nodes=None,
remove_descendants: bool = False,
remove_ancestors: bool = False,
) -> typing.Set:
result = self._apply_action_to_edge_sets(
source_neighbors,
comparison_neighbors,
action,
)

if remove_descendants:
assert graph is not None
descendants = get_descendants(graph)
logger.debug(descendants)

# import ipdb
# ipdb.set_trace()
descendants_to_ignore = set()
for k, v in descendants.items():
if k is self.owner or k in self.nodes:
# if k is self.owner or k in self.nodes:
if k in self.nodes:
descendants_to_ignore = descendants_to_ignore.union(v)

modified_result = result - descendants_to_ignore
Expand All @@ -2751,6 +2763,26 @@ def _process_action(
)
result = modified_result

if remove_ancestors:
assert graph is not None
ancestors = get_ancestors(graph)
logger.debug(ancestors)

# import ipdb
# ipdb.set_trace()
ancestors_to_ignore = set()
for k, v in ancestors.items():
# if k is self.owner or k in self.nodes:
if k in self.nodes:
ancestors_to_ignore = ancestors_to_ignore.union(v)

modified_result = result - ancestors_to_ignore
if modified_result != result:
logger.debug(
f'Removing ancestors {ancestors_to_ignore} from {result} giving {modified_result}'
)
result = modified_result

return result

@staticmethod
Expand Down Expand Up @@ -2946,7 +2978,7 @@ def __init__(
*nodes,
owner_senders: typing.Union[Action, str] = Action.REPLACE,
owner_receivers: typing.Union[Action, str] = Action.MERGE,
subject_senders: typing.Union[Action, str, dict] = Action.REPLACE,
subject_senders: typing.Union[Action, str, dict] = Action.MERGE,
subject_receivers: typing.Union[Action, str, dict] = Action.KEEP,
reconnect_non_subject_receivers: bool = True,
):
Expand Down
16 changes: 16 additions & 0 deletions src/graph_scheduler/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,19 @@ def cached_descendants(g):
return {node: nx.descendants(nx_graph, node) for node in graph}

return cached_descendants(frozen_graph(graph))


def get_ancestors(graph: typing.Dict) -> typing.Dict:
"""
Returns a dict containing the ancestors of each node in dependency
dictionary **graph**
Args:
graph (dict): a dependency dictionary
"""
@functools.lru_cache()
def cached_ancestors(g):
nx_graph = nx.DiGraph(graph).reverse()
return {node: nx.ancestors(nx_graph, node) for node in graph}

return cached_ancestors(frozen_graph(graph))
5 changes: 2 additions & 3 deletions tests/scheduling/test_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2134,10 +2134,9 @@ def _single_condition_test_helper(
'five_node_hub', 'D', ['C'], None,
{'A': set(), 'B': set(), 'C': {'A', 'B', 'D'}, 'D': {'A', 'B'}, 'E': {'C'}},
),
# creates cycle due to subject_senders default MERGE
(
'five_node_hub', 'D', ['A'], None,
{'A': {'D'}, 'B': set(), 'C': {'A', 'B'}, 'D': {'C'}, 'E': {'C'}}
{'A': {'D'}, 'B': set(), 'C': {'A', 'B'}, 'D': set(), 'E': {'C'}}
),
(
'five_node_hub', 'C', ['D'], r'.*C is already before D.*Condition is ignored.',
Expand All @@ -2160,7 +2159,7 @@ def _single_condition_test_helper(
'B': set(),
'C': {'A', 'B', 'E'},
'D': {'A', 'C'},
'E': {'A', 'B', 'H'},
'E': {'A', 'B'},
'F': {'C', 'D', 'E'},
'G': {'E'},
'H': {'G'},
Expand Down

0 comments on commit 0580d3e

Please sign in to comment.