Skip to content

Commit

Permalink
more coverage and fixes
Browse files Browse the repository at this point in the history
- fix dispatching in asyncio
- fix graph initialization for added models
  • Loading branch information
aleneum committed Mar 18, 2020
1 parent f59c595 commit d3363b9
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 18 deletions.
84 changes: 83 additions & 1 deletion tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from unittest.mock import MagicMock
from unittest import skipIf
from .test_core import TestTransitions
from .test_core import TestTransitions, MachineError
from .utils import DummyModel


@skipIf(asyncio is None, "AsyncMachine requires asyncio and contextvars suppport")
Expand Down Expand Up @@ -106,6 +107,87 @@ async def fix():
assert m2.is_C()
loop.close()

def test_async_callback_arguments(self):

async def process(should_fail=True):
if should_fail is not False:
raise ValueError("should_fail has been set")

self.machine.on_enter_B(process)
with self.assertRaises(ValueError):
asyncio.run(self.machine.go())
asyncio.run(self.machine.to_A())
asyncio.run(self.machine.go(should_fail=False))

def test_async_callback_event_data(self):

state_a = self.machine_cls.state_cls('A')
state_b = self.machine_cls.state_cls('B')

def sync_condition(event_data):
return event_data.state == state_a

async def process(event_data):
self.assertEqual(event_data.state, state_b)

m = self.machine_cls(states=[state_a, state_b], initial='A', send_event=True)
m.add_transition('go', 'A', 'B', conditions=sync_condition, after=process)
m.add_transition('go', 'B', 'A', conditions=sync_condition)
asyncio.run(m.go())
self.assertTrue(m.is_B())
asyncio.run(m.go())
self.assertTrue(m.is_B())

def test_async_invalid_triggers(self):
asyncio.run(self.machine.to_B())
with self.assertRaises(MachineError):
asyncio.run(self.machine.go())
self.machine.ignore_invalid_triggers = True
asyncio.run(self.machine.go())
self.assertTrue(self.machine.is_B())

def test_async_dispatch(self):
model1 = DummyModel()
model2 = DummyModel()
model3 = DummyModel()

machine = self.machine_cls(model=None, states=['A', 'B', 'C'], transitions=[['go', 'A', 'B'],
['go', 'B', 'C'],
['go', 'C', 'A']], initial='A')
machine.add_model(model1)
machine.add_model(model2, initial='B')
machine.add_model(model3, initial='C')
asyncio.run(machine.dispatch('go'))
self.assertTrue(model1.is_B())
self.assertEqual('C', model2.state)
self.assertEqual(machine.initial, model3.state)

def test_queued(self):
states = ['A', 'B', 'C', 'D']
# Define with list of dictionaries

async def change_state(machine):
self.assertEqual(machine.state, 'A')
if machine.has_queue:
await machine.run(machine=machine)
self.assertEqual(machine.state, 'A')
else:
with self.assertRaises(MachineError):
await machine.run(machine=machine)

transitions = [
{'trigger': 'walk', 'source': 'A', 'dest': 'B', 'before': change_state},
{'trigger': 'run', 'source': 'B', 'dest': 'C'},
{'trigger': 'sprint', 'source': 'C', 'dest': 'D'}
]

m = self.machine_cls(states=states, transitions=transitions, initial='A')
asyncio.run(m.walk(machine=m))
self.assertEqual(m.state, 'B')
m = self.machine_cls(states=states, transitions=transitions, initial='A', queued=True)
asyncio.run(m.walk(machine=m))
self.assertEqual(m.state, 'C')


class AsyncGraphMachine(TestAsync):

Expand Down
8 changes: 4 additions & 4 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_init_machine_with_hella_arguments(self):
}
]
s = Stuff()
m = Machine(model=s, states=states, transitions=transitions, initial='State2')
m = s.machine_cls(model=s, states=states, transitions=transitions, initial='State2')
s.advance()
self.assertEqual(s.message, 'Hello World!')

Expand All @@ -68,11 +68,11 @@ def test_property_initial(self):
{'trigger': 'run', 'source': 'B', 'dest': 'C'},
{'trigger': 'sprint', 'source': 'C', 'dest': 'D'}
]
m = Machine(states=states, transitions=transitions, initial='A')
m = self.stuff.machine_cls(states=states, transitions=transitions, initial='A')
self.assertEqual(m.initial, 'A')
m = Machine(states=states, transitions=transitions, initial='C')
m = self.stuff.machine_cls(states=states, transitions=transitions, initial='C')
self.assertEqual(m.initial, 'C')
m = Machine(states=states, transitions=transitions)
m = self.stuff.machine_cls(states=states, transitions=transitions)
self.assertEqual(m.initial, 'initial')

def test_transition_definitions(self):
Expand Down
8 changes: 4 additions & 4 deletions tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@
@skipIf(enum is None, "enum is not available")
class TestEnumsAsStates(TestCase):

machine_cls = MachineFactory.get_predefined()

def setUp(self):
class States(enum.Enum):
RED = 1
YELLOW = 2
GREEN = 3
self.machine_cls = MachineFactory.get_predefined()
self.States = States

def test_pass_enums_as_states(self):
Expand Down Expand Up @@ -122,7 +121,9 @@ def goodbye(self):
@skipIf(enum is None, "enum is not available")
class TestNestedStateEnums(TestEnumsAsStates):

machine_cls = MachineFactory.get_predefined(nested=True)
def setUp(self):
super(TestNestedStateEnums, self).setUp()
self.machine_cls = MachineFactory.get_predefined(nested=True)

def test_root_enums(self):
states = [self.States.RED, self.States.YELLOW,
Expand All @@ -132,7 +133,6 @@ def test_root_enums(self):
self.assertTrue(m.is_GREEN_tick())
m.to_RED()
self.assertTrue(m.state is self.States.RED)
# self.assertEqual(m.state, self.States.GREEN)

def test_nested_enums(self):
states = ['A', self.States.GREEN,
Expand Down
9 changes: 8 additions & 1 deletion tests/test_graphviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
except ImportError:
pass

from .utils import Stuff
from .utils import Stuff, DummyModel
from .test_core import TestTransitions

from transitions.extensions import MachineFactory
Expand Down Expand Up @@ -88,6 +88,13 @@ def test_diagram(self):
target.close()
os.unlink(target.name)

def test_transition_custom_model(self):
m = self.machine_cls(model=None, states=self.states, transitions=self.transitions, initial='A',
auto_transitions=False, title='a test', use_pygraphviz=self.use_pygraphviz)
model = DummyModel()
m.add_model(model)
model.walk()

def test_add_custom_state(self):
m = self.machine_cls(states=self.states, transitions=self.transitions, initial='A', auto_transitions=False,
title='a test', use_pygraphviz=self.use_pygraphviz)
Expand Down
22 changes: 21 additions & 1 deletion tests/test_nesting_legacy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .test_nesting import TestTransitions, Stuff
from .test_nesting import TestTransitions, Stuff, default_separator
from .test_reuse import TestTransitions as TestReuse
from .test_enum import TestNestedStateEnums
from transitions.extensions.nesting_legacy import HierarchicalMachine

try:
Expand Down Expand Up @@ -107,3 +108,22 @@ def __init__(self):
self.assertTrue(top_machine.nested.mock.called)
self.assertIsNot(top_machine.nested.get_state('2').on_enter,
top_machine.get_state('B{0}2'.format(separator)).on_enter)


class TestLegacyNestedEnum(TestNestedStateEnums):

def setUp(self):
super(TestLegacyNestedEnum, self).setUp()
self.machine_cls = HierarchicalMachine
self.machine_cls.state_cls.separator = default_separator

def test_nested_enums(self):
# Nested enums are currently not support since model.state does not contain any information about parents
# and nesting
states = ['A', 'B',
{'name': 'C', 'children': self.States, 'initial': self.States.GREEN}]
with self.assertRaises(AttributeError):
# NestedState will raise an error when parent is not None and state name is an enum
# Initializing this would actually work but `m.to_A()` would raise an error in get_state(m.state)
# as Machine is not aware of the location of States.GREEN
m = self.machine_cls(states=states, initial='C')
2 changes: 1 addition & 1 deletion transitions/extensions/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ async def dispatch(self, trigger, *args, **kwargs): # ToDo: not tested
Returns:
bool The truth value of all triggers combined with AND
"""
return asyncio.gather(*[getattr(model, trigger)(*args, **kwargs) for model in self.models])
await asyncio.gather(*[getattr(model, trigger)(*args, **kwargs) for model in self.models])

async def callbacks(self, funcs, event_data):
""" Triggers a list of callbacks """
Expand Down
17 changes: 11 additions & 6 deletions transitions/extensions/diagrams.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from transitions import Transition
from transitions.extensions.markup import MarkupMachine
from transitions.core import listify

import warnings
import logging
Expand Down Expand Up @@ -139,12 +140,6 @@ def __init__(self, *args, **kwargs):
_LOGGER.debug("Using graph engine %s", self.graph_cls)
_super(GraphMachine, self).__init__(*args, **kwargs)

# Create graph at beginning
for model in self.models:
if hasattr(model, 'get_graph'):
raise AttributeError('Model already has a get_graph attribute. Graph retrieval cannot be bound.')
setattr(model, 'get_graph', partial(self._get_graph, model))
_ = model.get_graph(title=self.title, force_new=True) # initialises graph
# for backwards compatibility assign get_combined_graph to get_graph
# if model is not the machine
if not hasattr(self, 'get_graph'):
Expand Down Expand Up @@ -199,6 +194,16 @@ def get_combined_graph(self, title=None, force_new=False, show_roi=False):
'method will return a combined graph of all models.')
return self._get_graph(self.models[0], title, force_new, show_roi)

def add_model(self, model, initial=None):
models = listify(model)
super(GraphMachine, self).add_model(models, initial)
for mod in models:
mod = self if mod == 'self' else mod
if hasattr(mod, 'get_graph'):
raise AttributeError('Model already has a get_graph attribute. Graph retrieval cannot be bound.')
setattr(mod, 'get_graph', partial(self._get_graph, mod))
_ = mod.get_graph(title=self.title, force_new=True) # initialises graph

def add_states(self, states, on_enter=None, on_exit=None,
ignore_invalid_triggers=None, **kwargs):
""" Calls the base method and regenerates all models's graphs. """
Expand Down

0 comments on commit d3363b9

Please sign in to comment.