diff --git a/src/graph_scheduler/scheduler.py b/src/graph_scheduler/scheduler.py index a08a3c3..d5168f8 100644 --- a/src/graph_scheduler/scheduler.py +++ b/src/graph_scheduler/scheduler.py @@ -33,6 +33,11 @@ ConditionSets can also be added after the Scheduler has been created, using its `add_condition` and `add_condition_set` methods, respectively. +`Graph structure Conditions ` are +applied to the Scheduler's graph in the order in which they are `added +`. + + .. _Scheduler_Algorithm: Algorithm @@ -41,8 +46,12 @@ .. _Consideration_Set: When a Scheduler is created, it constructs a `consideration_queue`: a list of ``consideration sets`` -that defines the order in which nodes are eligible to be executed. This is based on the dependencies specified in the graph -specification provided in the Scheduler's constructor. Each ``consideration_set`` +that defines the order in which nodes are eligible to be executed. This +is determined by the topological ordering of the `graph +` provided to the `Scheduler's constructor +`, which is then modified by any `graph structure conditions +` that are `added +` to the Scheduler. Each ``consideration_set`` is a set of nodes that are eligible to execute at the same time/`CONSIDERATION_SET_EXECUTION` (i.e., that appear at the same "depth" in a sequence of dependencies, and among which there are no dependencies). The first ``consideration_set`` consists of only origin nodes. The second consists of all nodes @@ -153,11 +162,13 @@ *Termination Conditions* ~~~~~~~~~~~~~~~~~~~~~~~~ -Termination conditions are `Conditions ` that specify when the open-ended units of time - `ENVIRONMENT_STATE_UPDATE +Termination conditions are basic `Conditions ` that specify +when the open-ended units of time - `ENVIRONMENT_STATE_UPDATE ` and `ENVIRONMENT_SEQUENCE` - have ended. By default, the termination condition for an `ENVIRONMENT_STATE_UPDATE ` is `AllHaveRun`, which is satisfied when all nodes have run at least once within the environment state update, and the termination condition for an `ENVIRONMENT_SEQUENCE` is when all of its constituent environment state updates have terminated. - +`Graph structure conditions ` cannot +be used as termination conditions. .. _Scheduler_Absolute_Time: @@ -302,20 +313,25 @@ import enum import fractions import logging -from typing import Union +import warnings +from typing import Dict, Hashable, Iterable, List, Set, Union import networkx as nx import numpy as np import pint -from toposort import toposort +from toposort import CircularDependencyError, toposort from graph_scheduler import _unit_registry from graph_scheduler.condition import ( - All, AllHaveRun, Always, Condition, ConditionSet, EveryNCalls, Never, - _parse_absolute_unit, _quantity_as_integer, + AddEdgeTo, All, AllHaveRun, Always, Condition, ConditionSet, EveryNCalls, + GraphStructureCondition, Never, RemoveEdgeFrom, _parse_absolute_unit, + _quantity_as_integer, typing_condition_base, ) from graph_scheduler.time import _get_pint_unit, Clock, TimeScale -from graph_scheduler.utilities import clone_graph, networkx_digraph_to_dependency_dict +from graph_scheduler.utilities import ( + cached_graph_function, clone_graph, networkx_digraph_to_dependency_dict, + typing_graph_dependency_dict, +) __all__ = [ 'Scheduler', 'SchedulerError', 'SchedulingMode', @@ -342,6 +358,29 @@ class SchedulingMode(enum.Enum): EXACT_TIME = enum.auto() +@cached_graph_function +def _build_consideration_queue( + graph: typing_graph_dependency_dict +) -> List[Set[Hashable]]: + return list(toposort(graph)) + + +def _generate_consideration_queue_indices( + consideration_queue: Iterable[Set[Hashable]] +) -> Dict[Hashable, int]: + """ + Returns: + A dictionary mapping nodes to their indices in + **consideration_queue** + """ + consideration_queue_indices = {} + for i, cs in enumerate(consideration_queue): + consideration_queue_indices.update({ + n: i for n in cs + }) + return consideration_queue_indices + + class SchedulerError(Exception): def __init__(self, error_value): @@ -433,11 +472,13 @@ def __init__( :param self: :param conditions: (ConditionSet) - a :keyword:`ConditionSet` to be scheduled """ - self.conditions = ConditionSet(conditions) + self.conditions = ConditionSet() + self._last_handled_structural_condition_order = None + + self._graphs = [] + self._consideration_queues = [] + self._consideration_queue_indices = [] - # the consideration queue is the ordered list of sets of nodes in the graph, by the - # order in which they should be checked to ensure that all parents have a chance to run before their children - self.consideration_queue = [] if termination_conds is None: termination_conds = default_termination_conds.copy() else: @@ -450,25 +491,25 @@ def __init__( self.default_absolute_time_unit = _parse_absolute_unit(default_absolute_time_unit) if isinstance(graph, nx.DiGraph): - self.dependency_dict = networkx_digraph_to_dependency_dict(graph) + base_graph = networkx_digraph_to_dependency_dict(graph) elif graph is None or isinstance(graph, nx.Graph): raise SchedulerError( 'Must instantiate a Scheduler with a graph dependency dict or a networkx.DiGraph' ) else: # add empty dependency set for senders that aren't present - self.dependency_dict = { + base_graph = { **{n: set() for n in set().union(*graph.values())}, **clone_graph(graph) } - self.consideration_queue = list(toposort(self.dependency_dict)) - self.nodes = [] - for consideration_set in self.consideration_queue: - for node in consideration_set: - self.nodes.append(node) + self._push_graph(base_graph) + + self.nodes = list(base_graph.keys()) - self._generate_consideration_queue_indices() + # add conditions after initial graph to deal with structural + if conditions is not None: + self.add_condition_set(conditions) self.default_execution_id = default_execution_id # stores the in order list of self.run's yielded outputs @@ -481,13 +522,6 @@ def __init__( self.date_creation = datetime.datetime.now() self.date_last_run_end = None - def _generate_consideration_queue_indices(self): - self.consideration_queue_indices = {} - for i, cs in enumerate(self.consideration_queue): - self.consideration_queue_indices.update({ - n: i for n in cs - }) - def _init_counts(self, execution_id, base_execution_id=NotImplemented): """ Attributes @@ -594,8 +628,14 @@ def _combine_termination_conditions(self, termination_conds): @staticmethod def _parse_termination_conditions(termination_conds): + err_msg = ( + "Termination conditions must be a dictionary of the form" + " {TimeScale: Condition, ...} and cannot include" + " GraphStructureCondition." + ) + # parse string representation of TimeScale - parsed_conds = {} + parsed_conds = termination_conds delkeys = set() for scale in termination_conds: try: @@ -604,21 +644,26 @@ def _parse_termination_conditions(termination_conds): except (AttributeError, TypeError): pass - termination_conds.update(parsed_conds) - try: - termination_conds = { - k: termination_conds[k] for k in termination_conds + parsed_conds = { + k: parsed_conds[k] + for k in parsed_conds if ( isinstance(k, TimeScale) - and isinstance(termination_conds[k], Condition) + and isinstance(parsed_conds[k], Condition) and k not in delkeys ) } except TypeError: - raise TypeError('termination_conditions must be a dictionary of the form {TimeScale: Condition, ...}') + raise TypeError(err_msg) else: - return termination_conds + invalid_conds = { + k: termination_conds[k] + for k in termination_conds.keys() - parsed_conds.keys() - delkeys + } + if len(invalid_conds) > 0: + raise SchedulerError(f"{err_msg} Invalid: {invalid_conds}") + return parsed_conds def end_environment_sequence(self, execution_id=NotImplemented): """Signals that an `ENVIRONMENT_SEQUENCE` has completed @@ -631,6 +676,40 @@ def end_environment_sequence(self, execution_id=NotImplemented): self._increment_time(TimeScale.ENVIRONMENT_SEQUENCE, execution_id) + def add_graph_edge(self, sender: Hashable, receiver: Hashable) -> AddEdgeTo: + """ + Adds an edge to the `graph ` from **sender** to + **receiver**. Equivalent to ``add_condition(sender, + AddEdgeTo(receiver))``. + + Args: + sender (Hashable): sender of the new edge + receiver (Hashable): receiver of the new edge + + Returns: + AddEdgeTo: the new condition added to implement the edge + """ + cond = AddEdgeTo(receiver) + self.add_condition(sender, cond) + return cond + + def remove_graph_edge(self, sender: Hashable, receiver: Hashable) -> RemoveEdgeFrom: + """ + Removes an edge from the `graph ` from + **sender** to **receiver** if it exists. Equivalent to + ``add_condition(receiver, RemoveEdgeFrom(sender))``. + + Args: + sender (Hashable): sender of the edge to be removed + receiver (Hashable): receiver of the edge to be removed + + Returns: + RemoveEdgeFrom: the new condition added to implement the edge + """ + cond = RemoveEdgeFrom(sender) + self.add_condition(receiver, cond) + return cond + ################################################################################ # Wrapper methods # to allow the user to ignore the ConditionSet internals @@ -638,11 +717,21 @@ def end_environment_sequence(self, execution_id=NotImplemented): def __contains__(self, item): return self.conditions.__contains__(item) - def add_condition(self, owner, condition): + def add_condition( + self, owner: Hashable, condition: typing_condition_base + ): """ - Adds a `Condition` to the Scheduler. If **owner** already has a Condition, it is overwritten - with the new one. If you want to add multiple conditions to a single owner, use the - `composite Conditions ` to accurately specify the desired behavior. + Adds a `basic ` or `graph structure + ` Condition to the Scheduler. + + If **condition** is basic, it will overwrite the current basic + Condition for **owner**, if present. If you want to add multiple + basic Conditions to a single owner, instead add a single + `Composite Condition ` to accurately + specify the desired behavior. + + If **condition** is structural, it will be applied on top of + `Scheduler.graph` in the order it is added. Arguments --------- @@ -651,17 +740,27 @@ def add_condition(self, owner, condition): specifies the node with which the **condition** should be associated. **condition** will govern the execution behavior of **owner** - condition : Condition + condition : ConditionBase specifies the Condition, associated with the **owner** to be added to the ConditionSet. """ self.conditions.add_condition(owner, condition) + self._handle_modified_structural_conditions() def add_condition_set(self, conditions): """ - Adds a set of `Conditions ` (in the form of a dict or another ConditionSet) to the Scheduler. - Any Condition added here will overwrite an existing Condition for a given owner. - If you want to add multiple conditions to a single owner, add a single `Composite Condition ` - to accurately specify the desired behavior. + Adds a set of `basic ` or `graph structure + ` Conditions (in the form of a dict or + another ConditionSet) to the Scheduler. + + Any basic Condition added here will overwrite the current basic + Condition for a given owner, if present. If you want to add + multiple basic Conditions to a single owner, instead add a + single `Composite Condition ` to + accurately specify the desired behavior. + + Any structural Condition added here will be applied on top of + `Scheduler.graph` in the order they are returned by iteration + over **conditions**. Arguments --------- @@ -675,19 +774,34 @@ def add_condition_set(self, conditions): """ self.conditions.add_condition_set(conditions) + self._handle_modified_structural_conditions() ################################################################################ # Validation methods # to provide the user with info if they do something odd ################################################################################ def _validate_run_state(self): + try: + _build_consideration_queue(self.graph) + except CircularDependencyError as e: + raise SchedulerError( + f'Cannot run on a graph that contains a cycle: {e}' + ) from e + + if ( + self.consideration_queue == 1 + and self.consideration_queue[0] == set() + and len(self.graph) != 0 + ): + raise SchedulerError('Unexpected empty consideration queue') + self._validate_conditions() def _validate_conditions(self): unspecified_nodes = [] for node in self.nodes: - if node not in self.conditions: - dependencies = list(self.dependency_dict[node]) + if node not in self.conditions.conditions_basic: + dependencies = list(self.graph[node]) if len(dependencies) == 0: cond = Always() elif len(dependencies) == 1: @@ -700,10 +814,19 @@ def _validate_conditions(self): if len(unspecified_nodes) > 0: logger.info( 'These nodes have no Conditions specified, and will be scheduled with conditions: {0}'.format( - {node: self.conditions[node] for node in unspecified_nodes} + {node: self.conditions.conditions_basic[node] for node in unspecified_nodes} ) ) + if ( + self.mode is SchedulingMode.EXACT_TIME + and len(self.conditions.conditions_structural) > 0 + ): + warnings.warn( + 'In exact time mode, graph structure conditions will have no' + f' effect: {self.conditions.conditions_structural}' + ) + ################################################################################ # Run methods ################################################################################ @@ -724,11 +847,6 @@ def run( """ self._validate_run_state() - if self.mode is SchedulingMode.EXACT_TIME: - effective_consideration_queue = [set(self.nodes)] - else: - effective_consideration_queue = self.consideration_queue - if termination_conds is None: termination_conds = self.termination_conds else: @@ -778,14 +896,14 @@ def run( cur_index_consideration_queue = 0 while ( - cur_index_consideration_queue < len(effective_consideration_queue) + cur_index_consideration_queue < len(self.consideration_queue) and not termination_conds[TimeScale.ENVIRONMENT_STATE_UPDATE].is_satisfied(**is_satisfied_kwargs) and not termination_conds[TimeScale.ENVIRONMENT_SEQUENCE].is_satisfied(**is_satisfied_kwargs) ): # all nodes to be added during this consideration set execution cur_consideration_set_execution_exec = set() # the current "layer/group" of nodes that MIGHT be added during this consideration set execution - cur_consideration_set = effective_consideration_queue[cur_index_consideration_queue] + cur_consideration_set = self.consideration_queue[cur_index_consideration_queue] try: iter(cur_consideration_set) @@ -801,7 +919,7 @@ def run( # only add each node once during a single consideration set execution, this also serves # to prevent infinitely cascading adds if current_node not in cur_consideration_set_execution_exec: - if self.conditions.conditions[current_node].is_satisfied(**is_satisfied_kwargs): + if self.conditions.conditions_basic[current_node].is_satisfied(**is_satisfied_kwargs): cur_consideration_set_execution_exec.add(current_node) execution_list_has_changed = True cur_consideration_set_has_changed = True @@ -897,7 +1015,7 @@ def get_absolute_conditions(self, termination_conds=None): return { owner: cond for owner, cond - in [*self.conditions.conditions.items(), *termination_conds.items()] + in [*self.conditions.conditions_basic.items(), *termination_conds.items()] if cond.is_absolute } @@ -942,3 +1060,98 @@ def termination_conds(self, termination_conds): @property def _in_exact_time_mode(self): return self.mode is SchedulingMode.EXACT_TIME or len(self.consideration_queue) == 1 + + def _handle_modified_structural_conditions(self): + if self._last_handled_structural_condition_order != self.conditions.structural_condition_order: + common_index = -1 + for i, cond in enumerate(self.conditions.structural_condition_order): + try: + if self._last_handled_structural_condition_order[i] == cond: + common_index = i + else: + break + except (IndexError, TypeError): + break + + # remove scheduler dependency dicts down to the common + # structural condition + if self._last_handled_structural_condition_order is not None: + # if anything must be done with the popped structures + # and cond, the iteration order should be reversed + for cond in self._last_handled_structural_condition_order[common_index + 1:]: + self._pop_graph() + + # add dependency dicts for new structural conditions + cur_graph = self._graphs[-1] + for cond in self.conditions.structural_condition_order[common_index + 1:]: + cur_graph = cond.modify_graph(cur_graph) + self._push_graph(cur_graph) + + self._last_handled_structural_condition_order = copy.copy( + self.conditions.structural_condition_order + ) + + def _push_graph(self, graph): + try: + consideration_queue = _build_consideration_queue(graph) + except CircularDependencyError as e: + consideration_queue = [set()] + try: + cond = self.conditions.structural_condition_order[-1] + except IndexError: + cond_str = 'Base graph' + else: + cond_str = f'Condition {cond} for {cond.owner} creates a cycle: {e}' + warnings.warn(cond_str) + + self._graphs.append(graph) + self._consideration_queues.append(consideration_queue) + self._consideration_queue_indices.append( + _generate_consideration_queue_indices(consideration_queue) + ) + + def _pop_graph(self): + return ( + self._graphs.pop(), + self._consideration_queues.pop(), + self._consideration_queue_indices.pop(), + ) + + @property + def graph(self) -> typing_graph_dependency_dict: + """ + The current graph used by this Scheduler, which is modified by + any `graph structure conditions + ` added + """ + self._handle_modified_structural_conditions() + return self._graphs[-1] + + # Maintain backwards compatibility for v1.x + @property + def dependency_dict(self) -> typing_graph_dependency_dict: + return self.graph + + @property + def consideration_queue(self) -> List[Set[Hashable]]: + """ + The ordered list of sets of nodes in the graph, by the order in + which they will be checked to ensure that all senders have a + chance to run before their receivers + """ + self._handle_modified_structural_conditions() + + if self.mode is SchedulingMode.EXACT_TIME: + return [set(self.graph.keys())] + else: + return self._consideration_queues[-1] + + @property + def consideration_queue_indices(self) -> Dict[Hashable, int]: + """ + A dictionary mapping the graph's nodes to their position in the + original consideration queue. This is the same as the + consideration queue when not using SchedulingMode.EXACT_TIME. + """ + self._handle_modified_structural_conditions() + return self._consideration_queue_indices[-1] diff --git a/tests/scheduling/test_scheduler.py b/tests/scheduling/test_scheduler.py index 4bd835b..3a1fa05 100644 --- a/tests/scheduling/test_scheduler.py +++ b/tests/scheduling/test_scheduler.py @@ -12,6 +12,12 @@ SimpleTestNode = pytest.helpers.get_test_node() +test_graphs = { + 'three_node_linear': pytest.helpers.create_graph_from_pathways(['A', 'B', 'C']), + 'four_node_split': pytest.helpers.create_graph_from_pathways(['A', 'B', 'D'], ['C', 'D']) +} + + class TestScheduler: stroop_paths = [ ['Color_Input', 'Color_Hidden', 'Output', 'Decision'], @@ -162,6 +168,121 @@ def test_delete_counts(self): assert sched.execution_list[eid_delete] == del_run_1 assert sched.execution_list[eid_repeat] == repeat_run_2 + repeat_run_2 + @pytest.mark.parametrize('add_method', ['add_graph_edge', 'add_condition_AddEdgeTo']) + @pytest.mark.parametrize('remove_method', ['remove_graph_edge', 'add_condition_RemoveEdgeFrom']) + def test_add_graph_structure_conditions(self, add_method, remove_method): + def add_condition(owner, condition): + if isinstance(condition, gs.AddEdgeTo) and add_method == 'add_graph_edge': + return scheduler.add_graph_edge(owner, condition.node) + elif isinstance(condition, gs.RemoveEdgeFrom) and remove_method == 'remove_graph_edge': + return scheduler.remove_graph_edge(condition.node, owner) + else: + scheduler.add_condition(owner, condition) + return condition + + initial_graph = pytest.helpers.create_graph_from_pathways(['A', 'B', 'C', 'D', 'E']) + initial_conds = {'A': gs.AddEdgeTo('C')} + scheduler = gs.Scheduler(initial_graph, initial_conds) + + assert scheduler.dependency_dict == { + **initial_graph, + **{'C': {'A', 'B'}}, + } + assert len(scheduler._graphs) == 2 + assert scheduler._graphs[0] == initial_graph + + addl_conditions = [ + ('B', gs.AddEdgeTo('D')), + ('B', gs.AddEdgeTo('E')), + ('C', gs.AddEdgeTo('E')), + ('E', gs.RemoveEdgeFrom('B')), + ('D', gs.RemoveEdgeFrom('B')), + ] + + for i, (owner, cond) in enumerate(addl_conditions): + added_cond = add_condition(owner, cond) + addl_conditions[i] = (owner, added_cond) + + assert scheduler.dependency_dict == { + 'A': set(), + 'B': {'A'}, + 'C': {'A', 'B'}, + 'D': {'C'}, + 'E': {'C', 'D'}, + } + assert scheduler._last_handled_structural_condition_order == ( + [initial_conds['A']] + [c[1] for c in addl_conditions] + ) + + # take only the first three elements in addl_conditions + addl_conds_sub_idx = 3 + scheduler.conditions = gs.ConditionSet({ + **{ + k: [ + addl_conditions[i][1] for i in range(addl_conds_sub_idx) + if addl_conditions[i][0] == k + ] + for k in initial_graph + }, + 'A': initial_conds['A'], + }) + assert scheduler.dependency_dict == { + 'A': set(), + 'B': {'A'}, + 'C': {'A', 'B'}, + 'D': {'B', 'C'}, + 'E': {'B', 'C', 'D'}, + } + assert scheduler._last_handled_structural_condition_order == ( + [initial_conds['A']] + [c[1] for c in addl_conditions[:addl_conds_sub_idx]] + ) + + @pytest.mark.parametrize( + 'graph_name, conditions, expected_output', + [ + ('three_node_linear', {'C': gs.BeforeNode('A')}, [{'C'}, {'A'}, {'B'}]), + ('three_node_linear', {'B': gs.AfterNodes('C')}, [{'A'}, {'C'}, {'B'}]), + ('four_node_split', {'D': gs.BeforeNodes('A', 'C')}, [{'D'}, {'A', 'C'}, {'B'}]), + ] + ) + def test_run_graph_structure_conditions(self, graph_name, conditions, expected_output): + scheduler = gs.Scheduler(test_graphs[graph_name], conditions) + output = list(scheduler.run()) + + assert output == expected_output + + def test_gsc_creates_cyclic_graph(self): + scheduler = gs.Scheduler( + pytest.helpers.create_graph_from_pathways(['A', 'B', 'C']) + ) + scheduler.add_condition('B', gs.EveryNCalls('A', 1)) + scheduler.add_condition('B', gs.AfterNode('C')) + with pytest.warns(UserWarning, match='for B creates a cycle:'): + scheduler.add_condition('B', gs.BeforeNode('A', prune_cycles=False)) + + # If _build_consideration_queue failure not explicitly detected + # and handled while adding BeforeNode('A') for 'B', the new + # modified cyclic graph is pushed but the condition is not + # added, resulting in incorrect state of scheduler._graphs. + # Assert this doesn't happen. + assert len(scheduler._graphs) == 3 + assert len(scheduler.conditions.structural_condition_order) == 2 + + with pytest.raises(gs.SchedulerError, match='contains a cycle'): + list(scheduler.run()) + + def test_gsc_exact_time_warning(self): + scheduler = gs.Scheduler( + {'A': set(), 'B': set()}, mode=gs.SchedulingMode.EXACT_TIME + ) + scheduler.add_condition('A', gs.AfterNode('B')) + + with pytest.warns( + UserWarning, + match='In exact time mode, graph structure conditions will have no effect' + ): + list(scheduler.run()) + class TestLinear: def test_no_termination_conds(self):