Skip to content

Commit

Permalink
mg big gsc sep methods of non sub recs and self ref edges
Browse files Browse the repository at this point in the history
  • Loading branch information
kmantel committed Oct 12, 2023
1 parent 16b70ff commit c09eb31
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 27 deletions.
64 changes: 39 additions & 25 deletions src/graph_scheduler/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2366,15 +2366,15 @@ def manipulate_graph(self, graph):
f"Root logger level changed to {logging.getLevelName(new_level)}"
)

res = self._manipulate_graph(clone_graph(graph))
result_graph = self._manipulate_graph(clone_graph(graph))

if self.debug:
logger.info(
f"Restoring root logger level to {logging.getLevelName(orig_log_level)}"
)
root_logger.setLevel(orig_log_level)

return res
return result_graph

@abc.abstractmethod
def _manipulate_graph(self, graph):
Expand Down Expand Up @@ -2548,6 +2548,33 @@ def _get_edge_conflicts(self, sender, receiver, new_sender_nodes, new_receiver_n

return conflict_strings

def _preprocess_graph(self, graph):
return graph

def _handle_reconnect_non_subject_receivers(self, graph, base_graph):
if self.reconnect_non_subject_receivers:
orig_owner_sender_nodes = base_graph[self.owner]
orig_receivers_old = get_receivers(base_graph)

for sender in orig_owner_sender_nodes:
for receiver in orig_receivers_old[self.owner]:
# only add as receiver
if receiver not in self.nodes:
graph[receiver].add(sender)
logger.debug(
f'Reconnecting {receiver} as receiver of {sender}'
)
return graph

def _handle_remove_new_self_referential_edges(self, graph, base_graph):
if self.remove_new_self_referential_edges:
for node in graph:
if node not in base_graph[node]:
if node in graph[node]:
graph[node].remove(node)
logger.debug(f'New self-referential edge removed for {node}')
return graph

# TODO: document:
# - owner_senders compares owner's original senders vs the union of senders of all subject nodes
# - owner_receivers compares owner's original receivers vs the union of receivers of all subject nodes
Expand All @@ -2556,7 +2583,10 @@ def _manipulate_graph(
self,
graph: typing.Dict[typing.Hashable, typing.Set]
) -> typing.Dict[typing.Hashable, typing.Set]:
unprocessed_graph = clone_graph(graph)
graph = self._preprocess_graph(graph)
graph = super()._manipulate_graph(graph)

all_receivers_old = get_receivers(graph)
old_owner_sender_nodes = graph[self.owner]
old_subject_sender_nodes = {n: graph[n] for n in self.nodes}
Expand Down Expand Up @@ -2636,22 +2666,10 @@ def _manipulate_graph(
# subject_senders: typing.Union[Action, str, dict] = Action.KEEP,
# subject_receivers: typing.Union[Action, str, dict] = Action.KEEP,
# because in this case, E should maintain its dependency on B
if self.reconnect_non_subject_receivers:
for sender in old_owner_sender_nodes:
# ONLY add new edges if self.owner is no longer a receiver of sender
if sender not in result_graph[self.owner]:
for receiver in all_receivers_old[self.owner]:
# only add as receiver
if receiver not in self.nodes:
result_graph[receiver].add(sender)

# TODO: move this to a step in manipulate_graph also?
if self.remove_new_self_referential_edges:
for node in result_graph:
if node not in graph[node]:
result_graph[node].discard(node)
result_graph = self._handle_reconnect_non_subject_receivers(result_graph, unprocessed_graph)
result_graph = self._handle_remove_new_self_referential_edges(result_graph, graph)

logger.debug(f'graph after {self} manipulate_graph:', result_graph)
logger.debug(f'graph after {self} _manipulate_graph:', result_graph)
return result_graph

def _apply_action_to_edge_sets(
Expand Down Expand Up @@ -2782,17 +2800,13 @@ def _validate_graph(self, graph: typing.Dict[typing.Hashable, typing.Set]):
# the changes to move reconnect_non_subject_receivers into outer
# manipulate_graph method
def _preprocess_graph(self, graph):
return graph

@abc.abstractmethod
def _manipulate_graph(self, graph):
for n in self.nodes:
graph[n].discard(self.owner)
graph[self.owner].discard(n)
return super()._manipulate_graph(graph)
return graph


class BeforeNodes(GSCWithActions, GSCReposition):
class BeforeNodes(GSCReposition, GSCWithActions):
_already_valid_message = 'before'

# TODO: choose preferred default for subject_receivers between KEEP and MERGE
Expand Down Expand Up @@ -2831,7 +2845,7 @@ class BeforeNode(BeforeNodes, GSCSingleNode):
pass


class WithNode(GSCWithActions, GSCSingleNode, GSCReposition):
class WithNode(GSCReposition, GSCWithActions, GSCSingleNode):
_already_valid_message = 'with'

# owner senders only valid as REPLACE or MERGE
Expand Down Expand Up @@ -2873,7 +2887,7 @@ def _manipulate_graph(self, graph):
return graph


class AfterNodes(GSCWithActions, GSCReposition):
class AfterNodes(GSCReposition, GSCWithActions):
_already_valid_message = 'after'

def __init__(
Expand Down
4 changes: 2 additions & 2 deletions tests/scheduling/test_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2089,7 +2089,7 @@ def custom_gsc_func_1(graph):
]
)
def test_CustomGraphStructureCondition(self, func, nodes, graph, expected_graph):
cond = gs.CustomGraphStructureCondition(func, *nodes)
cond = gs.CustomGraphStructureCondition(func, *nodes, debug=True)
cond.owner = 'A'
assert cond.manipulate_graph(graph) == expected_graph

Expand Down Expand Up @@ -2144,7 +2144,7 @@ def _single_condition_test_helper(
),
(
'five_node_hub', 'C', ['D', 'B'], r'.*C is already before D.*(?<!ignored.)$',
{'A': set(), 'B': {'C'}, 'C': set(), 'D': {'C'}, 'E': {'A', 'B'}},
{'A': set(), 'B': {'C'}, 'C': {'A'}, 'D': {'C'}, 'E': {'A', 'B', 'C'}},
),
(
'five_node_hub', 'C', ['D', 'E'], r'.*C is already before D,E.*Condition is ignored.',
Expand Down

0 comments on commit c09eb31

Please sign in to comment.