Skip to content

Commit

Permalink
Scheduler: update to support GraphStructureCondition
Browse files Browse the repository at this point in the history
  • Loading branch information
kmantel committed Oct 26, 2023
1 parent dd87858 commit 573dd4c
Show file tree
Hide file tree
Showing 2 changed files with 208 additions and 43 deletions.
194 changes: 151 additions & 43 deletions src/graph_scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,9 @@
import datetime
import enum
import fractions
import functools
import logging
from typing import Union
from typing import Dict, Hashable, Union

import networkx as nx
import numpy as np
Expand All @@ -311,11 +312,11 @@

from graph_scheduler import _unit_registry
from graph_scheduler.condition import (
All, AllHaveRun, Always, Condition, ConditionSet, EveryNCalls, Never,
_parse_absolute_unit, _quantity_as_integer,
AddEdgeTo, All, AllHaveRun, Always, Condition, ConditionSet, EveryNCalls,
Never, RemoveEdgeFrom, _parse_absolute_unit, _quantity_as_integer,
)
from graph_scheduler.time import _get_pint_unit, Clock, TimeScale
from graph_scheduler.utilities import clone_graph
from graph_scheduler.utilities import cached_graph_function, clone_graph

__all__ = [
'Scheduler', 'SchedulerError', 'SchedulingMode',
Expand All @@ -342,6 +343,20 @@ class SchedulingMode(enum.Enum):
EXACT_TIME = enum.auto()


@cached_graph_function
def _build_consideration_queue(dependency_dict):
return list(toposort(dependency_dict))


def _generate_consideration_queue_indices(consideration_queue):
consideration_queue_indices = {}
for i, cs in enumerate(consideration_queue):
consideration_queue_indices.update({
n: i for n in cs
})
return consideration_queue_indices


class SchedulerError(Exception):

def __init__(self, error_value):
Expand Down Expand Up @@ -427,11 +442,13 @@ def __init__(
:param self:
:param conditions: (ConditionSet) - a :keyword:`ConditionSet` to be scheduled
"""
self.conditions = ConditionSet(conditions)
self.conditions = ConditionSet()
self._last_handled_structural_condition_order = None

self._dependency_dicts = []
self._consideration_queues = []
self._consideration_queue_indices = []

# the consideration queue is the ordered list of sets of nodes in the graph, by the
# order in which they should be checked to ensure that all parents have a chance to run before their children
self.consideration_queue = []
if termination_conds is None:
termination_conds = default_termination_conds.copy()
else:
Expand All @@ -450,28 +467,28 @@ def __init__(
)

if isinstance(graph, nx.Graph):
self.dependency_dict = {}
base_dependency_dict = {}
for sender, receivers in graph.adj.items():
if sender not in self.dependency_dict:
self.dependency_dict[sender] = set()
if sender not in base_dependency_dict:
base_dependency_dict[sender] = set()
for rec in receivers:
if rec not in self.dependency_dict:
self.dependency_dict[rec] = set()
self.dependency_dict[rec].add(sender)
if rec not in base_dependency_dict:
base_dependency_dict[rec] = set()
base_dependency_dict[rec].add(sender)
else:
# add empty dependency set for senders that aren't present
self.dependency_dict = {
base_dependency_dict = {
**{n: set() for n in set().union(*graph.values())},
**clone_graph(graph)
}

self.consideration_queue = list(toposort(self.dependency_dict))
self.nodes = []
for consideration_set in self.consideration_queue:
for node in consideration_set:
self.nodes.append(node)
self._push_dependency_dict(base_dependency_dict)

self._generate_consideration_queue_indices()
self.nodes = list(base_dependency_dict.keys())

# add conditions after initial graph to deal with structural
if conditions is not None:
self.add_condition_set(conditions)

self.default_execution_id = default_execution_id
# stores the in order list of self.run's yielded outputs
Expand All @@ -484,13 +501,6 @@ def __init__(
self.date_creation = datetime.datetime.now()
self.date_last_run_end = None

def _generate_consideration_queue_indices(self):
self.consideration_queue_indices = {}
for i, cs in enumerate(self.consideration_queue):
self.consideration_queue_indices.update({
n: i for n in cs
})

def _init_counts(self, execution_id, base_execution_id=NotImplemented):
"""
Attributes
Expand Down Expand Up @@ -597,8 +607,14 @@ def _combine_termination_conditions(self, termination_conds):

@staticmethod
def _parse_termination_conditions(termination_conds):
err_msg = (
"Termination conditions must be a dictionary of the form"
" {TimeScale: Condition, ...} and cannot include"
" GraphStructureCondition."
)

# parse string representation of TimeScale
parsed_conds = {}
parsed_conds = termination_conds
delkeys = set()
for scale in termination_conds:
try:
Expand All @@ -607,21 +623,26 @@ def _parse_termination_conditions(termination_conds):
except (AttributeError, TypeError):
pass

termination_conds.update(parsed_conds)

try:
termination_conds = {
k: termination_conds[k] for k in termination_conds
parsed_conds = {
k: parsed_conds[k]
for k in parsed_conds
if (
isinstance(k, TimeScale)
and isinstance(termination_conds[k], Condition)
and isinstance(parsed_conds[k], Condition)
and k not in delkeys
)
}
except TypeError:
raise TypeError('termination_conditions must be a dictionary of the form {TimeScale: Condition, ...}')
raise TypeError(err_msg)
else:
return termination_conds
invalid_conds = {
k: termination_conds[k]
for k in termination_conds.keys() - parsed_conds.keys() - delkeys
}
if len(invalid_conds) > 0:
raise SchedulerError(f"{err_msg} Invalid: {invalid_conds}")
return parsed_conds

def end_environment_sequence(self, execution_id=NotImplemented):
"""Signals that an `ENVIRONMENT_SEQUENCE` has completed
Expand All @@ -634,6 +655,12 @@ def end_environment_sequence(self, execution_id=NotImplemented):

self._increment_time(TimeScale.ENVIRONMENT_SEQUENCE, execution_id)

def add_graph_edge(self, sender, receiver):
self.add_condition(sender, AddEdgeTo(receiver))

def remove_graph_edge(self, sender, receiver):
self.add_condition(receiver, RemoveEdgeFrom(sender))

################################################################################
# Wrapper methods
# to allow the user to ignore the ConditionSet internals
Expand All @@ -657,6 +684,12 @@ def add_condition(self, owner, condition):
condition : Condition
specifies the Condition, associated with the **owner** to be added to the ConditionSet.
"""

# TODO: validation
# - check structural condition not added during exact time mode
# - check owner and subject nodes in scheduler nodes (all conditions?? if not, absolutely for graph structure conds)
# TODO:
# - set a flag indicating the graph structure needs to be rebuilt if GSC added
self.conditions.add_condition(owner, condition)

def add_condition_set(self, conditions):
Expand Down Expand Up @@ -684,6 +717,9 @@ def add_condition_set(self, conditions):
# to provide the user with info if they do something odd
################################################################################
def _validate_run_state(self):
# TODO:
# - check if graph is acyclic, if not, see if any graph structure conditions did it
# - check if run previously using a different effective graph and send a warning (?)
self._validate_conditions()

def _validate_conditions(self):
Expand Down Expand Up @@ -727,11 +763,6 @@ def run(
"""
self._validate_run_state()

if self.mode is SchedulingMode.EXACT_TIME:
effective_consideration_queue = [set(self.nodes)]
else:
effective_consideration_queue = self.consideration_queue

if termination_conds is None:
termination_conds = self.termination_conds
else:
Expand Down Expand Up @@ -781,14 +812,14 @@ def run(
cur_index_consideration_queue = 0

while (
cur_index_consideration_queue < len(effective_consideration_queue)
cur_index_consideration_queue < len(self.consideration_queue)
and not termination_conds[TimeScale.ENVIRONMENT_STATE_UPDATE].is_satisfied(**is_satisfied_kwargs)
and not termination_conds[TimeScale.ENVIRONMENT_SEQUENCE].is_satisfied(**is_satisfied_kwargs)
):
# all nodes to be added during this consideration set execution
cur_consideration_set_execution_exec = set()
# the current "layer/group" of nodes that MIGHT be added during this consideration set execution
cur_consideration_set = effective_consideration_queue[cur_index_consideration_queue]
cur_consideration_set = self.consideration_queue[cur_index_consideration_queue]

try:
iter(cur_consideration_set)
Expand Down Expand Up @@ -942,3 +973,80 @@ def termination_conds(self, termination_conds):
@property
def _in_exact_time_mode(self):
return self.mode is SchedulingMode.EXACT_TIME or len(self.consideration_queue) == 1

def _handle_modified_structural_conditions(self):
if self._last_handled_structural_condition_order != self.conditions.structural_condition_order:
common_index = -1
for i, cond in enumerate(self.conditions.structural_condition_order):
try:
if self._last_handled_structural_condition_order[i] == cond:
common_index = i
else:
break
except (IndexError, TypeError):
break

# remove scheduler dependency dicts down to the common
# structural condition
if self._last_handled_structural_condition_order is not None:
# if anything must be done with the popped structures
# and cond, the iteration order should be reversed
for cond in self._last_handled_structural_condition_order[common_index + 1:]:
self._pop_dependency_dict()

# add dependency dicts for new structural conditions
cur_dependency_dict = self._dependency_dicts[-1]
for cond in self.conditions.structural_condition_order[common_index + 1:]:
cur_dependency_dict = cond.modify_graph(cur_dependency_dict)
self._push_dependency_dict(cur_dependency_dict)

self._last_handled_structural_condition_order = copy.copy(self.conditions.structural_condition_order)

def _push_dependency_dict(self, dependency_dict):
self._dependency_dicts.append(dependency_dict)

consideration_queue = _build_consideration_queue(dependency_dict)
self._consideration_queues.append(consideration_queue)
self._consideration_queue_indices.append(
_generate_consideration_queue_indices(consideration_queue)
)

def _pop_dependency_dict(self):
return (
self._dependency_dicts.pop(),
self._consideration_queues.pop(),
self._consideration_queue_indices.pop(),
)

@property
def dependency_dict(self):
self._handle_modified_structural_conditions()
return self._dependency_dicts[-1]

@property
def consideration_queue(self):
"""
The ordered list of sets of nodes in the graph, by the order in
which they will be checked to ensure that all senders have a
chance to run before their receivers
"""
self._handle_modified_structural_conditions()

if self.mode is SchedulingMode.EXACT_TIME:
return [set(self.dependency_dict.keys())]
else:
return self._consideration_queues[-1]

@property
def consideration_queue_indices(self) -> Dict[Hashable, int]:
"""
A dictionary mapping the graph's nodes to their position in
the original consideration queue. This is the same as the
consideration queue when not using
SchedulingMode.EXACT_TIME.
Returns:
Dict[Hashable, int]
"""
self._handle_modified_structural_conditions()
return self._consideration_queue_indices[-1]
57 changes: 57 additions & 0 deletions tests/scheduling/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,63 @@ def test_delete_counts(self):
assert sched.execution_list[eid_delete] == del_run_1
assert sched.execution_list[eid_repeat] == repeat_run_2 + repeat_run_2

def test_add_structural_conditions(self):
initial_graph = pytest.helpers.gen_linear_graph('A', 'B', 'C', 'D', 'E')
initial_conds = {'A': gs.AddEdgeTo('C')}
scheduler = gs.Scheduler(initial_graph, initial_conds)

assert scheduler.dependency_dict == {
**initial_graph,
**{'C': {'A', 'B'}},
}
assert len(scheduler._dependency_dicts) == 2
assert scheduler._dependency_dicts[0] == initial_graph

addl_conditions = [
('B', gs.AddEdgeTo('D')),
('B', gs.AddEdgeTo('E')),
('C', gs.AddEdgeTo('E')),
('E', gs.RemoveEdgeFrom('B')),
('D', gs.RemoveEdgeFrom('B')),
]

for owner, cond in addl_conditions:
scheduler.add_condition(owner, cond)

assert scheduler.dependency_dict == {
'A': set(),
'B': {'A'},
'C': {'A', 'B'},
'D': {'C'},
'E': {'C', 'D'},
}
assert scheduler._last_handled_structural_condition_order == (
[initial_conds['A']] + [c[1] for c in addl_conditions]
)

# take only the first three elements in addl_conditions
addl_conds_sub_idx = 3
scheduler.conditions = gs.ConditionSet({
**{
k: [
addl_conditions[i][1] for i in range(addl_conds_sub_idx)
if addl_conditions[i][0] == k
]
for k in initial_graph
},
'A': initial_conds['A'],
})
assert scheduler.dependency_dict == {
'A': set(),
'B': {'A'},
'C': {'A', 'B'},
'D': {'B', 'C'},
'E': {'B', 'C', 'D'},
}
assert scheduler._last_handled_structural_condition_order == (
[initial_conds['A']] + [c[1] for c in addl_conditions[:addl_conds_sub_idx]]
)


@pytest.mark.psyneulink
class TestLinear:
Expand Down

0 comments on commit 573dd4c

Please sign in to comment.