Skip to content

Commit

Permalink
mg wip
Browse files Browse the repository at this point in the history
  • Loading branch information
kmantel committed Dec 5, 2023
1 parent ffb51a8 commit 142ba94
Showing 1 changed file with 45 additions and 55 deletions.
100 changes: 45 additions & 55 deletions tests/scheduling/test_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,13 @@ def test_AtPass_underconstrained(self):
B = 'B'
C = 'C'
sched = gs.Scheduler(graph={A: set(), B: {A}, C: {B}})
sched.add_condition(A, pnl.AtPass(0))
sched.add_condition(B, pnl.Always())
sched.add_condition(C, pnl.Always())
sched.add_condition(A, gs.AtPass(0))
sched.add_condition(B, gs.Always())
sched.add_condition(C, gs.Always())

termination_conds = {}
termination_conds[pnl.TimeScale.ENVIRONMENT_SEQUENCE] = pnl.AfterNEnvironmentStateUpdates(1)
termination_conds[pnl.TimeScale.ENVIRONMENT_STATE_UPDATE] = pnl.AfterNCalls(C, 2)
termination_conds[gs.TimeScale.ENVIRONMENT_SEQUENCE] = gs.AfterNEnvironmentStateUpdates(1)
termination_conds[gs.TimeScale.ENVIRONMENT_STATE_UPDATE] = gs.AfterNCalls(C, 2)
output = list(sched.run(termination_conds=termination_conds))

expected_output = [A, B, C, B, C]
Expand All @@ -265,11 +265,11 @@ def test_AtPass_underconstrained(self):
def test_AtPass_in_middle(self):
A = 'A'
sched = gs.Scheduler(graph={A: set()})
sched.add_condition(A, pnl.AtPass(2))
sched.add_condition(A, gs.AtPass(2))

termination_conds = {}
termination_conds[pnl.TimeScale.ENVIRONMENT_SEQUENCE] = pnl.AfterNEnvironmentStateUpdates(1)
termination_conds[pnl.TimeScale.ENVIRONMENT_STATE_UPDATE] = pnl.AtPass(5)
termination_conds[gs.TimeScale.ENVIRONMENT_SEQUENCE] = gs.AfterNEnvironmentStateUpdates(1)
termination_conds[gs.TimeScale.ENVIRONMENT_STATE_UPDATE] = gs.AtPass(5)
output = list(sched.run(termination_conds=termination_conds))

expected_output = [set(), set(), A, set(), set()]
Expand All @@ -278,11 +278,11 @@ def test_AtPass_in_middle(self):
def test_AtPass_at_end(self):
A = 'A'
sched = gs.Scheduler(graph={A: set()})
sched.add_condition(A, pnl.AtPass(5))
sched.add_condition(A, gs.AtPass(5))

termination_conds = {}
termination_conds[pnl.TimeScale.ENVIRONMENT_SEQUENCE] = pnl.AfterNEnvironmentStateUpdates(1)
termination_conds[pnl.TimeScale.ENVIRONMENT_STATE_UPDATE] = pnl.AtPass(5)
termination_conds[gs.TimeScale.ENVIRONMENT_SEQUENCE] = gs.AfterNEnvironmentStateUpdates(1)
termination_conds[gs.TimeScale.ENVIRONMENT_STATE_UPDATE] = gs.AtPass(5)
output = list(sched.run(termination_conds=termination_conds))

expected_output = [set(), set(), set(), set(), set()]
Expand All @@ -291,11 +291,11 @@ def test_AtPass_at_end(self):
def test_AtPass_after_end(self):
A = 'A'
sched = gs.Scheduler(graph={A: set()})
sched.add_condition(A, pnl.AtPass(6))
sched.add_condition(A, gs.AtPass(6))

termination_conds = {}
termination_conds[pnl.TimeScale.ENVIRONMENT_SEQUENCE] = pnl.AfterNEnvironmentStateUpdates(1)
termination_conds[pnl.TimeScale.ENVIRONMENT_STATE_UPDATE] = pnl.AtPass(5)
termination_conds[gs.TimeScale.ENVIRONMENT_SEQUENCE] = gs.AfterNEnvironmentStateUpdates(1)
termination_conds[gs.TimeScale.ENVIRONMENT_STATE_UPDATE] = gs.AtPass(5)
output = list(sched.run(termination_conds=termination_conds))

expected_output = [set(), set(), set(), set(), set()]
Expand All @@ -304,11 +304,11 @@ def test_AtPass_after_end(self):
def test_AfterPass(self):
A = 'A'
sched = gs.Scheduler(graph={A: set()})
sched.add_condition(A, pnl.AfterPass(0))
sched.add_condition(A, gs.AfterPass(0))

termination_conds = {}
termination_conds[pnl.TimeScale.ENVIRONMENT_SEQUENCE] = pnl.AfterNEnvironmentStateUpdates(1)
termination_conds[pnl.TimeScale.ENVIRONMENT_STATE_UPDATE] = pnl.AtPass(5)
termination_conds[gs.TimeScale.ENVIRONMENT_SEQUENCE] = gs.AfterNEnvironmentStateUpdates(1)
termination_conds[gs.TimeScale.ENVIRONMENT_STATE_UPDATE] = gs.AtPass(5)
output = list(sched.run(termination_conds=termination_conds))

expected_output = [set(), A, A, A, A]
Expand All @@ -317,11 +317,11 @@ def test_AfterPass(self):
def test_AfterNPasses(self):
A = 'A'
sched = gs.Scheduler(graph={A: set()})
sched.add_condition(A, pnl.AfterNPasses(1))
sched.add_condition(A, gs.AfterNPasses(1))

termination_conds = {}
termination_conds[pnl.TimeScale.ENVIRONMENT_SEQUENCE] = pnl.AfterNEnvironmentStateUpdates(1)
termination_conds[pnl.TimeScale.ENVIRONMENT_STATE_UPDATE] = pnl.AtPass(5)
termination_conds[gs.TimeScale.ENVIRONMENT_SEQUENCE] = gs.AfterNEnvironmentStateUpdates(1)
termination_conds[gs.TimeScale.ENVIRONMENT_STATE_UPDATE] = gs.AtPass(5)
output = list(sched.run(termination_conds=termination_conds))

expected_output = [set(), A, A, A, A]
Expand Down Expand Up @@ -592,33 +592,28 @@ def test_AtEnvironmentStateUpdateStart(self):

@pytest.mark.psyneulink
def test_composite_condition_multi(self):
comp = pnl.Composition()
A = pnl.TransferMechanism(name='A')
B = pnl.TransferMechanism(name='B')
C = pnl.TransferMechanism(name='C')
for m in [A, B, C]:
comp.add_node(m)
comp.add_projection(pnl.MappingProjection(), A, B)
comp.add_projection(pnl.MappingProjection(), B, C)
sched = pnl.Scheduler(**pytest.helpers.composition_to_scheduler_args(comp))

sched.add_condition(A, pnl.EveryNPasses(1))
sched.add_condition(B, pnl.EveryNCalls(A, 2))
sched.add_condition(C, pnl.All(
pnl.Any(
pnl.AfterPass(6),
pnl.AfterNCalls(B, 2)
A = 'A'
B = 'B'
C = 'C'
sched = gs.Scheduler(graph={A: set(), B: {A}, C: {B}})

sched.add_condition(A, gs.EveryNPasses(1))
sched.add_condition(B, gs.EveryNCalls(A, 2))
sched.add_condition(C, gs.All(
gs.Any(
gs.AfterPass(6),
gs.AfterNCalls(B, 2)
),
pnl.Any(
pnl.AfterPass(2),
pnl.AfterNCalls(B, 3)
gs.Any(
gs.AfterPass(2),
gs.AfterNCalls(B, 3)
)
)
)

termination_conds = {}
termination_conds[pnl.TimeScale.ENVIRONMENT_SEQUENCE] = pnl.AfterNEnvironmentStateUpdates(1)
termination_conds[pnl.TimeScale.ENVIRONMENT_STATE_UPDATE] = pnl.AfterNCalls(C, 3)
termination_conds[gs.TimeScale.ENVIRONMENT_SEQUENCE] = gs.AfterNEnvironmentStateUpdates(1)
termination_conds[gs.TimeScale.ENVIRONMENT_STATE_UPDATE] = gs.AfterNCalls(C, 3)
output = list(sched.run(termination_conds=termination_conds))
expected_output = [
A, A, B, A, A, B, C, A, C, A, B, C
Expand Down Expand Up @@ -683,23 +678,18 @@ def test_AllHaveRun(self):

@pytest.mark.psyneulink
def test_AllHaveRun_2(self):
comp = pnl.Composition()
A = pnl.TransferMechanism(name='A')
B = pnl.TransferMechanism(name='B')
C = pnl.TransferMechanism(name='C')
for m in [A, B, C]:
comp.add_node(m)
comp.add_projection(pnl.MappingProjection(), A, B)
comp.add_projection(pnl.MappingProjection(), B, C)
sched = pnl.Scheduler(**pytest.helpers.composition_to_scheduler_args(comp))
A = 'A'
B = 'B'
C = 'C'
sched = gs.Scheduler(graph={A: set(), B: {A}, C: {B}})

sched.add_condition(A, pnl.EveryNPasses(1))
sched.add_condition(B, pnl.EveryNCalls(A, 2))
sched.add_condition(C, pnl.EveryNCalls(B, 2))
sched.add_condition(A, gs.EveryNPasses(1))
sched.add_condition(B, gs.EveryNCalls(A, 2))
sched.add_condition(C, gs.EveryNCalls(B, 2))

termination_conds = {}
termination_conds[pnl.TimeScale.ENVIRONMENT_SEQUENCE] = pnl.AfterNEnvironmentStateUpdates(1)
termination_conds[pnl.TimeScale.ENVIRONMENT_STATE_UPDATE] = pnl.AllHaveRun()
termination_conds[gs.TimeScale.ENVIRONMENT_SEQUENCE] = gs.AfterNEnvironmentStateUpdates(1)
termination_conds[gs.TimeScale.ENVIRONMENT_STATE_UPDATE] = gs.AllHaveRun()
output = list(sched.run(termination_conds=termination_conds))

expected_output = [
Expand Down

0 comments on commit 142ba94

Please sign in to comment.