From 142ba94edba7d786d1322b2fe587593b895e78c9 Mon Sep 17 00:00:00 2001 From: Katherine Mantel Date: Tue, 5 Dec 2023 02:16:45 +0000 Subject: [PATCH] mg wip --- tests/scheduling/test_condition.py | 100 +++++++++++++---------------- 1 file changed, 45 insertions(+), 55 deletions(-) diff --git a/tests/scheduling/test_condition.py b/tests/scheduling/test_condition.py index d8c8a7fb..3d0969b3 100644 --- a/tests/scheduling/test_condition.py +++ b/tests/scheduling/test_condition.py @@ -250,13 +250,13 @@ def test_AtPass_underconstrained(self): B = 'B' C = 'C' sched = gs.Scheduler(graph={A: set(), B: {A}, C: {B}}) - sched.add_condition(A, pnl.AtPass(0)) - sched.add_condition(B, pnl.Always()) - sched.add_condition(C, pnl.Always()) + sched.add_condition(A, gs.AtPass(0)) + sched.add_condition(B, gs.Always()) + sched.add_condition(C, gs.Always()) termination_conds = {} - termination_conds[pnl.TimeScale.ENVIRONMENT_SEQUENCE] = pnl.AfterNEnvironmentStateUpdates(1) - termination_conds[pnl.TimeScale.ENVIRONMENT_STATE_UPDATE] = pnl.AfterNCalls(C, 2) + termination_conds[gs.TimeScale.ENVIRONMENT_SEQUENCE] = gs.AfterNEnvironmentStateUpdates(1) + termination_conds[gs.TimeScale.ENVIRONMENT_STATE_UPDATE] = gs.AfterNCalls(C, 2) output = list(sched.run(termination_conds=termination_conds)) expected_output = [A, B, C, B, C] @@ -265,11 +265,11 @@ def test_AtPass_underconstrained(self): def test_AtPass_in_middle(self): A = 'A' sched = gs.Scheduler(graph={A: set()}) - sched.add_condition(A, pnl.AtPass(2)) + sched.add_condition(A, gs.AtPass(2)) termination_conds = {} - termination_conds[pnl.TimeScale.ENVIRONMENT_SEQUENCE] = pnl.AfterNEnvironmentStateUpdates(1) - termination_conds[pnl.TimeScale.ENVIRONMENT_STATE_UPDATE] = pnl.AtPass(5) + termination_conds[gs.TimeScale.ENVIRONMENT_SEQUENCE] = gs.AfterNEnvironmentStateUpdates(1) + termination_conds[gs.TimeScale.ENVIRONMENT_STATE_UPDATE] = gs.AtPass(5) output = list(sched.run(termination_conds=termination_conds)) expected_output = [set(), set(), A, set(), set()] @@ -278,11 +278,11 @@ def test_AtPass_in_middle(self): def test_AtPass_at_end(self): A = 'A' sched = gs.Scheduler(graph={A: set()}) - sched.add_condition(A, pnl.AtPass(5)) + sched.add_condition(A, gs.AtPass(5)) termination_conds = {} - termination_conds[pnl.TimeScale.ENVIRONMENT_SEQUENCE] = pnl.AfterNEnvironmentStateUpdates(1) - termination_conds[pnl.TimeScale.ENVIRONMENT_STATE_UPDATE] = pnl.AtPass(5) + termination_conds[gs.TimeScale.ENVIRONMENT_SEQUENCE] = gs.AfterNEnvironmentStateUpdates(1) + termination_conds[gs.TimeScale.ENVIRONMENT_STATE_UPDATE] = gs.AtPass(5) output = list(sched.run(termination_conds=termination_conds)) expected_output = [set(), set(), set(), set(), set()] @@ -291,11 +291,11 @@ def test_AtPass_at_end(self): def test_AtPass_after_end(self): A = 'A' sched = gs.Scheduler(graph={A: set()}) - sched.add_condition(A, pnl.AtPass(6)) + sched.add_condition(A, gs.AtPass(6)) termination_conds = {} - termination_conds[pnl.TimeScale.ENVIRONMENT_SEQUENCE] = pnl.AfterNEnvironmentStateUpdates(1) - termination_conds[pnl.TimeScale.ENVIRONMENT_STATE_UPDATE] = pnl.AtPass(5) + termination_conds[gs.TimeScale.ENVIRONMENT_SEQUENCE] = gs.AfterNEnvironmentStateUpdates(1) + termination_conds[gs.TimeScale.ENVIRONMENT_STATE_UPDATE] = gs.AtPass(5) output = list(sched.run(termination_conds=termination_conds)) expected_output = [set(), set(), set(), set(), set()] @@ -304,11 +304,11 @@ def test_AtPass_after_end(self): def test_AfterPass(self): A = 'A' sched = gs.Scheduler(graph={A: set()}) - sched.add_condition(A, pnl.AfterPass(0)) + sched.add_condition(A, gs.AfterPass(0)) termination_conds = {} - termination_conds[pnl.TimeScale.ENVIRONMENT_SEQUENCE] = pnl.AfterNEnvironmentStateUpdates(1) - termination_conds[pnl.TimeScale.ENVIRONMENT_STATE_UPDATE] = pnl.AtPass(5) + termination_conds[gs.TimeScale.ENVIRONMENT_SEQUENCE] = gs.AfterNEnvironmentStateUpdates(1) + termination_conds[gs.TimeScale.ENVIRONMENT_STATE_UPDATE] = gs.AtPass(5) output = list(sched.run(termination_conds=termination_conds)) expected_output = [set(), A, A, A, A] @@ -317,11 +317,11 @@ def test_AfterPass(self): def test_AfterNPasses(self): A = 'A' sched = gs.Scheduler(graph={A: set()}) - sched.add_condition(A, pnl.AfterNPasses(1)) + sched.add_condition(A, gs.AfterNPasses(1)) termination_conds = {} - termination_conds[pnl.TimeScale.ENVIRONMENT_SEQUENCE] = pnl.AfterNEnvironmentStateUpdates(1) - termination_conds[pnl.TimeScale.ENVIRONMENT_STATE_UPDATE] = pnl.AtPass(5) + termination_conds[gs.TimeScale.ENVIRONMENT_SEQUENCE] = gs.AfterNEnvironmentStateUpdates(1) + termination_conds[gs.TimeScale.ENVIRONMENT_STATE_UPDATE] = gs.AtPass(5) output = list(sched.run(termination_conds=termination_conds)) expected_output = [set(), A, A, A, A] @@ -592,33 +592,28 @@ def test_AtEnvironmentStateUpdateStart(self): @pytest.mark.psyneulink def test_composite_condition_multi(self): - comp = pnl.Composition() - A = pnl.TransferMechanism(name='A') - B = pnl.TransferMechanism(name='B') - C = pnl.TransferMechanism(name='C') - for m in [A, B, C]: - comp.add_node(m) - comp.add_projection(pnl.MappingProjection(), A, B) - comp.add_projection(pnl.MappingProjection(), B, C) - sched = pnl.Scheduler(**pytest.helpers.composition_to_scheduler_args(comp)) - - sched.add_condition(A, pnl.EveryNPasses(1)) - sched.add_condition(B, pnl.EveryNCalls(A, 2)) - sched.add_condition(C, pnl.All( - pnl.Any( - pnl.AfterPass(6), - pnl.AfterNCalls(B, 2) + A = 'A' + B = 'B' + C = 'C' + sched = gs.Scheduler(graph={A: set(), B: {A}, C: {B}}) + + sched.add_condition(A, gs.EveryNPasses(1)) + sched.add_condition(B, gs.EveryNCalls(A, 2)) + sched.add_condition(C, gs.All( + gs.Any( + gs.AfterPass(6), + gs.AfterNCalls(B, 2) ), - pnl.Any( - pnl.AfterPass(2), - pnl.AfterNCalls(B, 3) + gs.Any( + gs.AfterPass(2), + gs.AfterNCalls(B, 3) ) ) ) termination_conds = {} - termination_conds[pnl.TimeScale.ENVIRONMENT_SEQUENCE] = pnl.AfterNEnvironmentStateUpdates(1) - termination_conds[pnl.TimeScale.ENVIRONMENT_STATE_UPDATE] = pnl.AfterNCalls(C, 3) + termination_conds[gs.TimeScale.ENVIRONMENT_SEQUENCE] = gs.AfterNEnvironmentStateUpdates(1) + termination_conds[gs.TimeScale.ENVIRONMENT_STATE_UPDATE] = gs.AfterNCalls(C, 3) output = list(sched.run(termination_conds=termination_conds)) expected_output = [ A, A, B, A, A, B, C, A, C, A, B, C @@ -683,23 +678,18 @@ def test_AllHaveRun(self): @pytest.mark.psyneulink def test_AllHaveRun_2(self): - comp = pnl.Composition() - A = pnl.TransferMechanism(name='A') - B = pnl.TransferMechanism(name='B') - C = pnl.TransferMechanism(name='C') - for m in [A, B, C]: - comp.add_node(m) - comp.add_projection(pnl.MappingProjection(), A, B) - comp.add_projection(pnl.MappingProjection(), B, C) - sched = pnl.Scheduler(**pytest.helpers.composition_to_scheduler_args(comp)) + A = 'A' + B = 'B' + C = 'C' + sched = gs.Scheduler(graph={A: set(), B: {A}, C: {B}}) - sched.add_condition(A, pnl.EveryNPasses(1)) - sched.add_condition(B, pnl.EveryNCalls(A, 2)) - sched.add_condition(C, pnl.EveryNCalls(B, 2)) + sched.add_condition(A, gs.EveryNPasses(1)) + sched.add_condition(B, gs.EveryNCalls(A, 2)) + sched.add_condition(C, gs.EveryNCalls(B, 2)) termination_conds = {} - termination_conds[pnl.TimeScale.ENVIRONMENT_SEQUENCE] = pnl.AfterNEnvironmentStateUpdates(1) - termination_conds[pnl.TimeScale.ENVIRONMENT_STATE_UPDATE] = pnl.AllHaveRun() + termination_conds[gs.TimeScale.ENVIRONMENT_SEQUENCE] = gs.AfterNEnvironmentStateUpdates(1) + termination_conds[gs.TimeScale.ENVIRONMENT_STATE_UPDATE] = gs.AllHaveRun() output = list(sched.run(termination_conds=termination_conds)) expected_output = [