Skip to content

Commit

Permalink
WIP: gsc add remove descendant
Browse files Browse the repository at this point in the history
  • Loading branch information
kmantel committed Oct 16, 2023
1 parent a8baa3d commit 60405a9
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 6 deletions.
54 changes: 49 additions & 5 deletions src/graph_scheduler/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def converge(node, thresh):

from graph_scheduler import _unit_registry
from graph_scheduler.time import TimeScale
from graph_scheduler.utilities import call_with_pruned_args, clone_graph, get_receivers, gs_logging_formatter
from graph_scheduler.utilities import call_with_pruned_args, clone_graph, get_descendants, get_receivers, gs_logging_formatter


_additional__all__ = ['Action', 'ConditionError', 'ConditionSet']
Expand Down Expand Up @@ -2416,6 +2416,7 @@ def __init__(
remove_new_self_referential_edges: bool = True,
debug: bool = False,
ignore_conflicts: bool = False,
remove_descendant_senders: bool = True,
**kwargs,
):
subject_senders = self._handle_subject_arg(
Expand All @@ -2435,6 +2436,7 @@ def __init__(
remove_new_self_referential_edges=remove_new_self_referential_edges,
debug=debug,
ignore_conflicts=ignore_conflicts,
remove_descendant_senders=remove_descendant_senders,
**kwargs
)

Expand Down Expand Up @@ -2603,13 +2605,17 @@ def _manipulate_graph(
result_graph = clone_graph(graph)

logger.debug(f'Apply owner_senders ({self.owner_senders}) to {self.owner}')
new_sender_nodes[self.owner] = self._apply_action_to_edge_sets(
new_sender_nodes[self.owner] = self._process_action(
old_owner_sender_nodes,
set.union(*old_subject_sender_nodes.values()),
self.owner_senders,
graph,
# [self.owner],
self.nodes,
self.remove_descendant_senders,
)
logger.debug(f'Apply owner_receivers ({self.owner_receivers}) to {self.owner}')
new_receiver_nodes[self.owner] = self._apply_action_to_edge_sets(
new_receiver_nodes[self.owner] = self._process_action(
all_receivers_old[self.owner],
set.union(*[all_receivers_old[n] for n in self.nodes]),
self.owner_receivers,
Expand All @@ -2620,13 +2626,17 @@ def _manipulate_graph(
subject_receivers_n = self._get_subject_action(self.subject_receivers, n)

logger.debug(f'Apply subject_senders ({subject_senders_n}) to {n}')
new_sender_nodes[n] = self._apply_action_to_edge_sets(
new_sender_nodes[n] = self._process_action(
old_subject_sender_nodes[n],
old_owner_sender_nodes,
subject_senders_n,
graph,
# self.nodes,
[self.owner],
self.remove_descendant_senders,
)
logger.debug(f'Apply subject_receivers ({subject_receivers_n}) to {n}')
new_receiver_nodes[n] = self._apply_action_to_edge_sets(
new_receiver_nodes[n] = self._process_action(
all_receivers_old[n],
all_receivers_old[self.owner],
subject_receivers_n,
Expand Down Expand Up @@ -2709,6 +2719,40 @@ def _apply_action_to_edge_sets(
)
return res

def _process_action(
self,
source_neighbors: typing.Set,
comparison_neighbors: typing.Set,
action: Action,
graph=None,
comparison_nodes=None,
remove_descendants: bool = False,
) -> typing.Set:
result = self._apply_action_to_edge_sets(
source_neighbors,
comparison_neighbors,
action,
)
if remove_descendants:
descendants = get_descendants(graph)
logger.debug(descendants)

# import ipdb
# ipdb.set_trace()
descendants_to_ignore = set()
for k, v in descendants.items():
if k is self.owner or k in self.nodes:
descendants_to_ignore = descendants_to_ignore.union(v)

modified_result = result - descendants_to_ignore
if modified_result != result:
logger.debug(
f'Removing descendants {descendants_to_ignore} from {result} giving {modified_result}'
)
result = modified_result

return result

@staticmethod
def _parse_action_arg(action):
try:
Expand Down
20 changes: 19 additions & 1 deletion src/graph_scheduler/utilities.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import collections
import functools
import inspect
import logging
import networkx as nx
import typing
import weakref

import networkx as nx

__all__ = ['clone_graph', 'get_receivers', 'output_graph_png']

_unused_args_sig_cache = weakref.WeakKeyDictionary()
Expand Down Expand Up @@ -138,3 +140,19 @@ def output_graph_png(
pd = nx.drawing.nx_pydot.to_pydot(gr)
pd.write_png(filename)
print(f'Wrote png to {filename}')


def get_descendants(graph: typing.Dict) -> typing.Dict:
"""
Returns a dict containing the descendants of each node in dependency
dictionary **graph**
Args:
graph (dict): a dependency dictionary
"""
@functools.lru_cache()
def cached_descendants(g):
nx_graph = nx.DiGraph(graph).reverse()
return {node: nx.descendants(nx_graph, node) for node in graph}

return cached_descendants(frozen_graph(graph))

0 comments on commit 60405a9

Please sign in to comment.