From 573dd4c0c6451c4a14c9bd169d22c40cc769e263 Mon Sep 17 00:00:00 2001 From: Katherine Mantel Date: Mon, 23 Oct 2023 22:07:12 +0000 Subject: [PATCH] Scheduler: update to support GraphStructureCondition --- src/graph_scheduler/scheduler.py | 194 ++++++++++++++++++++++------- tests/scheduling/test_scheduler.py | 57 +++++++++ 2 files changed, 208 insertions(+), 43 deletions(-) diff --git a/src/graph_scheduler/scheduler.py b/src/graph_scheduler/scheduler.py index 5da9975d..ec4652a8 100644 --- a/src/graph_scheduler/scheduler.py +++ b/src/graph_scheduler/scheduler.py @@ -301,8 +301,9 @@ import datetime import enum import fractions +import functools import logging -from typing import Union +from typing import Dict, Hashable, Union import networkx as nx import numpy as np @@ -311,11 +312,11 @@ from graph_scheduler import _unit_registry from graph_scheduler.condition import ( - All, AllHaveRun, Always, Condition, ConditionSet, EveryNCalls, Never, - _parse_absolute_unit, _quantity_as_integer, + AddEdgeTo, All, AllHaveRun, Always, Condition, ConditionSet, EveryNCalls, + Never, RemoveEdgeFrom, _parse_absolute_unit, _quantity_as_integer, ) from graph_scheduler.time import _get_pint_unit, Clock, TimeScale -from graph_scheduler.utilities import clone_graph +from graph_scheduler.utilities import cached_graph_function, clone_graph __all__ = [ 'Scheduler', 'SchedulerError', 'SchedulingMode', @@ -342,6 +343,20 @@ class SchedulingMode(enum.Enum): EXACT_TIME = enum.auto() +@cached_graph_function +def _build_consideration_queue(dependency_dict): + return list(toposort(dependency_dict)) + + +def _generate_consideration_queue_indices(consideration_queue): + consideration_queue_indices = {} + for i, cs in enumerate(consideration_queue): + consideration_queue_indices.update({ + n: i for n in cs + }) + return consideration_queue_indices + + class SchedulerError(Exception): def __init__(self, error_value): @@ -427,11 +442,13 @@ def __init__( :param self: :param conditions: (ConditionSet) - a :keyword:`ConditionSet` to be scheduled """ - self.conditions = ConditionSet(conditions) + self.conditions = ConditionSet() + self._last_handled_structural_condition_order = None + + self._dependency_dicts = [] + self._consideration_queues = [] + self._consideration_queue_indices = [] - # the consideration queue is the ordered list of sets of nodes in the graph, by the - # order in which they should be checked to ensure that all parents have a chance to run before their children - self.consideration_queue = [] if termination_conds is None: termination_conds = default_termination_conds.copy() else: @@ -450,28 +467,28 @@ def __init__( ) if isinstance(graph, nx.Graph): - self.dependency_dict = {} + base_dependency_dict = {} for sender, receivers in graph.adj.items(): - if sender not in self.dependency_dict: - self.dependency_dict[sender] = set() + if sender not in base_dependency_dict: + base_dependency_dict[sender] = set() for rec in receivers: - if rec not in self.dependency_dict: - self.dependency_dict[rec] = set() - self.dependency_dict[rec].add(sender) + if rec not in base_dependency_dict: + base_dependency_dict[rec] = set() + base_dependency_dict[rec].add(sender) else: # add empty dependency set for senders that aren't present - self.dependency_dict = { + base_dependency_dict = { **{n: set() for n in set().union(*graph.values())}, **clone_graph(graph) } - self.consideration_queue = list(toposort(self.dependency_dict)) - self.nodes = [] - for consideration_set in self.consideration_queue: - for node in consideration_set: - self.nodes.append(node) + self._push_dependency_dict(base_dependency_dict) - self._generate_consideration_queue_indices() + self.nodes = list(base_dependency_dict.keys()) + + # add conditions after initial graph to deal with structural + if conditions is not None: + self.add_condition_set(conditions) self.default_execution_id = default_execution_id # stores the in order list of self.run's yielded outputs @@ -484,13 +501,6 @@ def __init__( self.date_creation = datetime.datetime.now() self.date_last_run_end = None - def _generate_consideration_queue_indices(self): - self.consideration_queue_indices = {} - for i, cs in enumerate(self.consideration_queue): - self.consideration_queue_indices.update({ - n: i for n in cs - }) - def _init_counts(self, execution_id, base_execution_id=NotImplemented): """ Attributes @@ -597,8 +607,14 @@ def _combine_termination_conditions(self, termination_conds): @staticmethod def _parse_termination_conditions(termination_conds): + err_msg = ( + "Termination conditions must be a dictionary of the form" + " {TimeScale: Condition, ...} and cannot include" + " GraphStructureCondition." + ) + # parse string representation of TimeScale - parsed_conds = {} + parsed_conds = termination_conds delkeys = set() for scale in termination_conds: try: @@ -607,21 +623,26 @@ def _parse_termination_conditions(termination_conds): except (AttributeError, TypeError): pass - termination_conds.update(parsed_conds) - try: - termination_conds = { - k: termination_conds[k] for k in termination_conds + parsed_conds = { + k: parsed_conds[k] + for k in parsed_conds if ( isinstance(k, TimeScale) - and isinstance(termination_conds[k], Condition) + and isinstance(parsed_conds[k], Condition) and k not in delkeys ) } except TypeError: - raise TypeError('termination_conditions must be a dictionary of the form {TimeScale: Condition, ...}') + raise TypeError(err_msg) else: - return termination_conds + invalid_conds = { + k: termination_conds[k] + for k in termination_conds.keys() - parsed_conds.keys() - delkeys + } + if len(invalid_conds) > 0: + raise SchedulerError(f"{err_msg} Invalid: {invalid_conds}") + return parsed_conds def end_environment_sequence(self, execution_id=NotImplemented): """Signals that an `ENVIRONMENT_SEQUENCE` has completed @@ -634,6 +655,12 @@ def end_environment_sequence(self, execution_id=NotImplemented): self._increment_time(TimeScale.ENVIRONMENT_SEQUENCE, execution_id) + def add_graph_edge(self, sender, receiver): + self.add_condition(sender, AddEdgeTo(receiver)) + + def remove_graph_edge(self, sender, receiver): + self.add_condition(receiver, RemoveEdgeFrom(sender)) + ################################################################################ # Wrapper methods # to allow the user to ignore the ConditionSet internals @@ -657,6 +684,12 @@ def add_condition(self, owner, condition): condition : Condition specifies the Condition, associated with the **owner** to be added to the ConditionSet. """ + + # TODO: validation + # - check structural condition not added during exact time mode + # - check owner and subject nodes in scheduler nodes (all conditions?? if not, absolutely for graph structure conds) + # TODO: + # - set a flag indicating the graph structure needs to be rebuilt if GSC added self.conditions.add_condition(owner, condition) def add_condition_set(self, conditions): @@ -684,6 +717,9 @@ def add_condition_set(self, conditions): # to provide the user with info if they do something odd ################################################################################ def _validate_run_state(self): + # TODO: + # - check if graph is acyclic, if not, see if any graph structure conditions did it + # - check if run previously using a different effective graph and send a warning (?) self._validate_conditions() def _validate_conditions(self): @@ -727,11 +763,6 @@ def run( """ self._validate_run_state() - if self.mode is SchedulingMode.EXACT_TIME: - effective_consideration_queue = [set(self.nodes)] - else: - effective_consideration_queue = self.consideration_queue - if termination_conds is None: termination_conds = self.termination_conds else: @@ -781,14 +812,14 @@ def run( cur_index_consideration_queue = 0 while ( - cur_index_consideration_queue < len(effective_consideration_queue) + cur_index_consideration_queue < len(self.consideration_queue) and not termination_conds[TimeScale.ENVIRONMENT_STATE_UPDATE].is_satisfied(**is_satisfied_kwargs) and not termination_conds[TimeScale.ENVIRONMENT_SEQUENCE].is_satisfied(**is_satisfied_kwargs) ): # all nodes to be added during this consideration set execution cur_consideration_set_execution_exec = set() # the current "layer/group" of nodes that MIGHT be added during this consideration set execution - cur_consideration_set = effective_consideration_queue[cur_index_consideration_queue] + cur_consideration_set = self.consideration_queue[cur_index_consideration_queue] try: iter(cur_consideration_set) @@ -942,3 +973,80 @@ def termination_conds(self, termination_conds): @property def _in_exact_time_mode(self): return self.mode is SchedulingMode.EXACT_TIME or len(self.consideration_queue) == 1 + + def _handle_modified_structural_conditions(self): + if self._last_handled_structural_condition_order != self.conditions.structural_condition_order: + common_index = -1 + for i, cond in enumerate(self.conditions.structural_condition_order): + try: + if self._last_handled_structural_condition_order[i] == cond: + common_index = i + else: + break + except (IndexError, TypeError): + break + + # remove scheduler dependency dicts down to the common + # structural condition + if self._last_handled_structural_condition_order is not None: + # if anything must be done with the popped structures + # and cond, the iteration order should be reversed + for cond in self._last_handled_structural_condition_order[common_index + 1:]: + self._pop_dependency_dict() + + # add dependency dicts for new structural conditions + cur_dependency_dict = self._dependency_dicts[-1] + for cond in self.conditions.structural_condition_order[common_index + 1:]: + cur_dependency_dict = cond.modify_graph(cur_dependency_dict) + self._push_dependency_dict(cur_dependency_dict) + + self._last_handled_structural_condition_order = copy.copy(self.conditions.structural_condition_order) + + def _push_dependency_dict(self, dependency_dict): + self._dependency_dicts.append(dependency_dict) + + consideration_queue = _build_consideration_queue(dependency_dict) + self._consideration_queues.append(consideration_queue) + self._consideration_queue_indices.append( + _generate_consideration_queue_indices(consideration_queue) + ) + + def _pop_dependency_dict(self): + return ( + self._dependency_dicts.pop(), + self._consideration_queues.pop(), + self._consideration_queue_indices.pop(), + ) + + @property + def dependency_dict(self): + self._handle_modified_structural_conditions() + return self._dependency_dicts[-1] + + @property + def consideration_queue(self): + """ + The ordered list of sets of nodes in the graph, by the order in + which they will be checked to ensure that all senders have a + chance to run before their receivers + """ + self._handle_modified_structural_conditions() + + if self.mode is SchedulingMode.EXACT_TIME: + return [set(self.dependency_dict.keys())] + else: + return self._consideration_queues[-1] + + @property + def consideration_queue_indices(self) -> Dict[Hashable, int]: + """ + A dictionary mapping the graph's nodes to their position in + the original consideration queue. This is the same as the + consideration queue when not using + SchedulingMode.EXACT_TIME. + + Returns: + Dict[Hashable, int] + """ + self._handle_modified_structural_conditions() + return self._consideration_queue_indices[-1] diff --git a/tests/scheduling/test_scheduler.py b/tests/scheduling/test_scheduler.py index 993dce34..a4e32ee2 100644 --- a/tests/scheduling/test_scheduler.py +++ b/tests/scheduling/test_scheduler.py @@ -300,6 +300,63 @@ def test_delete_counts(self): assert sched.execution_list[eid_delete] == del_run_1 assert sched.execution_list[eid_repeat] == repeat_run_2 + repeat_run_2 + def test_add_structural_conditions(self): + initial_graph = pytest.helpers.gen_linear_graph('A', 'B', 'C', 'D', 'E') + initial_conds = {'A': gs.AddEdgeTo('C')} + scheduler = gs.Scheduler(initial_graph, initial_conds) + + assert scheduler.dependency_dict == { + **initial_graph, + **{'C': {'A', 'B'}}, + } + assert len(scheduler._dependency_dicts) == 2 + assert scheduler._dependency_dicts[0] == initial_graph + + addl_conditions = [ + ('B', gs.AddEdgeTo('D')), + ('B', gs.AddEdgeTo('E')), + ('C', gs.AddEdgeTo('E')), + ('E', gs.RemoveEdgeFrom('B')), + ('D', gs.RemoveEdgeFrom('B')), + ] + + for owner, cond in addl_conditions: + scheduler.add_condition(owner, cond) + + assert scheduler.dependency_dict == { + 'A': set(), + 'B': {'A'}, + 'C': {'A', 'B'}, + 'D': {'C'}, + 'E': {'C', 'D'}, + } + assert scheduler._last_handled_structural_condition_order == ( + [initial_conds['A']] + [c[1] for c in addl_conditions] + ) + + # take only the first three elements in addl_conditions + addl_conds_sub_idx = 3 + scheduler.conditions = gs.ConditionSet({ + **{ + k: [ + addl_conditions[i][1] for i in range(addl_conds_sub_idx) + if addl_conditions[i][0] == k + ] + for k in initial_graph + }, + 'A': initial_conds['A'], + }) + assert scheduler.dependency_dict == { + 'A': set(), + 'B': {'A'}, + 'C': {'A', 'B'}, + 'D': {'B', 'C'}, + 'E': {'B', 'C', 'D'}, + } + assert scheduler._last_handled_structural_condition_order == ( + [initial_conds['A']] + [c[1] for c in addl_conditions[:addl_conds_sub_idx]] + ) + @pytest.mark.psyneulink class TestLinear: