Skip to content

Commit

Permalink
increase coverage; cleanup
Browse files Browse the repository at this point in the history
test has_queue; remove unused code in nesting;
  • Loading branch information
aleneum committed Mar 19, 2020
1 parent d3363b9 commit fd79f08
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 45 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1283,6 +1283,14 @@ While you can reference substates as done in `['go', '2_z', '2_x']` you cannot r
When a parent state is exited, its children will also be exited.
In addition to the processing order of transitions known from `Machine` where transitions are considered in the order they were added, `HierarchicalMachine` considers hierarchy as well.
Transitions defined in substates will be evaluated first (e.g. `C_1_a` is left before `C_2_z`) and transitions defined with wildcard `*` will (for now) only add transitions to root states (in this example `A`, `B`, `C`)
Starting with *0.8.0* nested states can be added directly and will issue the creation of parent states on-the-fly:

```python
m = HierarchicalMachine(states=['A'], initial='A')
m.add_state('B_1_a')
m.to_B_1()
assert m.is_B(allow_substates=True)
```


#### Reuse of previously created HSMs
Expand Down
31 changes: 29 additions & 2 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,18 @@ def test_async_callback_event_data(self):
def sync_condition(event_data):
return event_data.state == state_a

async def process(event_data):
async def async_conditions(event_data):
return event_data.state == state_a

async def async_callback(event_data):
self.assertEqual(event_data.state, state_b)

def sync_callback(event_data):
self.assertEqual(event_data.state, state_b)

m = self.machine_cls(states=[state_a, state_b], initial='A', send_event=True)
m.add_transition('go', 'A', 'B', conditions=sync_condition, after=process)
m.add_transition('go', 'A', 'B', conditions=[sync_condition, async_conditions],
after=[sync_callback, async_callback])
m.add_transition('go', 'B', 'A', conditions=sync_condition)
asyncio.run(m.go())
self.assertTrue(m.is_B())
Expand Down Expand Up @@ -175,6 +182,16 @@ async def change_state(machine):
with self.assertRaises(MachineError):
await machine.run(machine=machine)

async def raise_machine_error(event_data):
self.assertTrue(event_data.machine.has_queue)
await event_data.model.to_A()
event_data.machine._queued = False
await event_data.model.to_C()

async def raise_exception(event_data):
await event_data.model.to_C()
raise ValueError("Clears queue")

transitions = [
{'trigger': 'walk', 'source': 'A', 'dest': 'B', 'before': change_state},
{'trigger': 'run', 'source': 'B', 'dest': 'C'},
Expand All @@ -187,6 +204,16 @@ async def change_state(machine):
m = self.machine_cls(states=states, transitions=transitions, initial='A', queued=True)
asyncio.run(m.walk(machine=m))
self.assertEqual(m.state, 'C')
m = self.machine_cls(states=states, initial='A', queued=True, send_event=True,
before_state_change=raise_machine_error)
with self.assertRaises(MachineError):
asyncio.run(m.to_C())
m = self.machine_cls(states=states, initial='A', queued=True, send_event=True)
m.add_transition('go', 'A', 'B', after='go')
m.add_transition('go', 'B', 'C', before=raise_exception)
with self.assertRaises(ValueError):
asyncio.run(m.go())
self.assertEqual('B', m.state)


class AsyncGraphMachine(TestAsync):
Expand Down
7 changes: 7 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ def on_exit_B(event):


class TestTransitions(TestCase):

def setUp(self):
self.stuff = Stuff()
self.machine_cls = Machine

def tearDown(self):
pass
Expand Down Expand Up @@ -1082,6 +1084,11 @@ def __init__(self):
self.assertEqual(instance.state_b, 'A')
self.assertEqual(instance.state_a, 'B')

def test_initial_not_registered(self):
m1 = self.machine_cls(states=['A', 'B'], initial=self.machine_cls.state_cls('C'))
self.assertTrue(m1.is_C())
self.assertTrue('C' in m1.states)


class TestWarnings(TestCase):
def test_warning(self):
Expand Down
5 changes: 5 additions & 0 deletions tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,8 @@ def test_nested_enums(self):
self.assertEqual(m1.state, m2.state)
m1.to_A()
self.assertNotEqual(m1.state, m2.state)

def test_initial_enum(self):
m1 = self.machine_cls(states=self.States, initial=self.States.GREEN)
self.assertEqual(self.States.GREEN, m1.state)
self.assertEqual(m1.state.name, self.States.GREEN.name)
41 changes: 36 additions & 5 deletions tests/test_nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from unittest import skipIf
from .test_core import TestTransitions as TestsCore
from .utils import Stuff
from .utils import Stuff, DummyModel

try:
from unittest.mock import MagicMock
Expand Down Expand Up @@ -88,7 +88,8 @@ def test_init_machine_with_nested_states(self):
b = State('B')
b_1 = State('1')
b_2 = State('2')
b.add_substates([b_1, b_2])
b.add_substate(b_1)
b.add_substates([b_2])
m = self.stuff.machine_cls(states=[a, b])
self.assertEqual(m.states['B'].states['1'], b_1)
m.to("B{0}1".format(State.separator))
Expand Down Expand Up @@ -219,6 +220,16 @@ def test_add_custom_state(self):
s.run()
self.assertEqual('C{0}3{0}a'.format(State.separator), s.state)

def test_add_nested_state(self):
m = self.machine_cls(states=['A'], initial='A')
m.add_state('B{0}1{0}a'.format(self.state_cls.separator))
self.assertIn('B', m.states)
self.assertIn('1', m.states['B'].states)
self.assertIn('a', m.states['B'].states['1'].states)

with self.assertRaises(ValueError):
m.add_state(m.states['A'])

def test_enter_exit_nested_state(self):
State = self.state_cls
mock = MagicMock()
Expand Down Expand Up @@ -562,6 +573,26 @@ def test_internal_transitions(self):
self.assertEqual(s.state, 'A')
self.assertEqual(s.level, 2)

def test_transition_with_unknown_state(self):
s = self.stuff
with self.assertRaises(ValueError):
s.machine.add_transition('next', 'A', s.machine.state_cls('X'))

def test_skip_to_override(self):
mock = MagicMock()
class Model:

def to(self):
mock()

model1 = Model()
model2 = DummyModel()
machine = self.machine_cls([model1, model2], states=['A', 'B'], initial='A')
model1.to()
model2.to('B')
self.assertTrue(mock.called)
self.assertTrue(model2.is_B())


@skipIf(pgv is None, 'NestedGraph diagram test requires graphviz')
class TestWithGraphTransitions(TestTransitions):
Expand All @@ -570,9 +601,9 @@ def setUp(self):
states = ['A', 'B', {'name': 'C', 'children': ['1', '2', {'name': '3', 'children': ['a', 'b', 'c']}]},
'D', 'E', 'F']

machine_cls = MachineFactory.get_predefined(graph=True, nested=True)
self.state_cls = machine_cls.state_cls
self.stuff = Stuff(states, machine_cls)
self.machine_cls = MachineFactory.get_predefined(graph=True, nested=True)
self.state_cls = self.machine_cls.state_cls
self.stuff = Stuff(states, self.machine_cls)

def test_ordered_with_graph(self):
State = self.state_cls
Expand Down
3 changes: 3 additions & 0 deletions tests/test_nesting_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def test_transitioning(self):
def test_nested_definitions(self):
pass

def test_add_nested_state(self):
pass # not supported by legacy machine


class TestReuseLegacy(TestReuse):

Expand Down
6 changes: 4 additions & 2 deletions tests/test_threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def heavy_checking():
class TestLockedTransitions(TestCore):

def setUp(self):
self.stuff = Stuff(machine_cls=MachineFactory.get_predefined(locked=True))
self.machine_cls = MachineFactory.get_predefined(locked=True)
self.stuff = Stuff(machine_cls=self.machine_cls)
self.stuff.heavy_processing = heavy_processing
self.stuff.machine.add_transition('forward', 'A', 'B', before='heavy_processing')

Expand Down Expand Up @@ -175,7 +176,8 @@ def setUp(self):
self.c3 = SomeContext(event_list=self.event_list)
self.c4 = SomeContext(event_list=self.event_list)

self.stuff = Stuff(machine_cls=MachineFactory.get_predefined(locked=True), extra_kwargs={
self.machine_cls = MachineFactory.get_predefined(locked=True)
self.stuff = Stuff(machine_cls=self.machine_cls, extra_kwargs={
'machine_context': [self.c1, self.c2]
})
self.stuff.machine.add_model(self.s1, model_context=[self.c3, self.c4])
Expand Down
2 changes: 0 additions & 2 deletions transitions/extensions/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,6 @@ async def _process(self, trigger):
else:
raise MachineError("Attempt to process events synchronously while transition queue is not empty!")

# ToDo: test for has_queue
# process queued events
self._transition_queue.append(trigger)
# another entry in the queue implies a running transition; skip immediate execution
if len(self._transition_queue) > 1:
Expand Down
35 changes: 5 additions & 30 deletions transitions/extensions/nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,32 +506,19 @@ def get_global_name(self, state=None, join=True):
if state:
if isinstance(state, State):
state = state.name
domains = self._get_global_name(state, domains)
if state in self.states:
domains.append(state)
else:
raise ValueError("State '{0}' not found in local states.".format(state))
return self.state_cls.separator.join(domains) if join else domains

def get_global_state(self, state):
states = self._stack[0][1] if self._stack else self.states
domains = state
for sco in domains:
states = states[sco].states
return states

def get_local_name(self, state_name, join=True):
if isinstance(state_name, Enum):
state_name = state_name.name
elif isinstance(state_name, State):
if state_name == self.scoped:
return '' if join else []
state_name = self.get_global_name(state_name)
state_name = state_name.split(self.state_cls.separator)
local_stack = [s[0] for s in self._stack] + [self.scoped]
local_stack_start = len(local_stack) - local_stack[::-1].index(self)
domains = [s.name for s in local_stack[local_stack_start:]]
if domains and state_name and state_name[0] != domains[0]:
return self.state_cls.separator.join(state_name) if join else state_name
while domains and state_name and state_name[0] == domains[0]:
state_name.pop(0)
domains.pop(0)
return self.state_cls.separator.join(state_name) if join else state_name

def get_nested_state_names(self):
Expand Down Expand Up @@ -715,18 +702,6 @@ def _add_trigger_to_model(self, trigger, model):
else:
self._checked_assignment(model, trigger, trig_func)

def _get_global_name(self, state=None, domains=[]):
domains.append(state)
if state in self.states:
return domains
else:
for child in self.states:
with self(child):
domains = self._get_global_name(state, domains)
if domains:
return [child] + domains
return domains

def _get_trigger(self, model, trigger_name, *args, **kwargs):
"""Convenience function added to the model to trigger events by name.
Args:
Expand Down Expand Up @@ -757,7 +732,7 @@ def _has_state(self, state, raise_error=False):
if not found:
for a_state in self.states:
with self(a_state):
if self.has_state(state):
if self._has_state(state):
return True
if not found and raise_error:
msg = 'State %s has not been added to the machine' % (state.name if hasattr(state, 'name') else state)
Expand Down
5 changes: 1 addition & 4 deletions transitions/extensions/nesting_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,7 @@ def parent(self, value):
def initial(self):
""" When this state is entered it will automatically enter
the child with this name if not None. """
try:
return self.name + NestedState.separator + self._initial if self._initial else self._initial
except TypeError: # we assume an Enum here
return self.name + NestedState.separator + self._initial.name
return self.name + NestedState.separator + self._initial if self._initial else self._initial

@initial.setter
def initial(self, value):
Expand Down

0 comments on commit fd79f08

Please sign in to comment.