diff --git a/docs/processing_model.md b/docs/processing_model.md index d699c1e4..56891939 100644 --- a/docs/processing_model.md +++ b/docs/processing_model.md @@ -77,4 +77,3 @@ after 'connection_succeed' from 'connecting' to 'connected' ```{note} Note that the events `connect` and `connection_succeed` are executed sequentially, and the `connect.after` runs on the expected order. ``` - diff --git a/docs/transitions.md b/docs/transitions.md index 1752224c..cfebbba8 100644 --- a/docs/transitions.md +++ b/docs/transitions.md @@ -84,7 +84,7 @@ Syntax: >>> draft = State("Draft") >>> draft.to.itself() -TransitionList([Transition(State('Draft', ... +TransitionList([Transition('Draft', 'Draft', event='', internal=False)]) ``` @@ -101,7 +101,7 @@ Syntax: >>> draft = State("Draft") >>> draft.to.itself(internal=True) -TransitionList([Transition(State('Draft', ... +TransitionList([Transition('Draft', 'Draft', event='', internal=True)]) ``` diff --git a/pyproject.toml b/pyproject.toml index 892485b0..0a8c7d3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -170,6 +170,7 @@ select = [ ignore = [ "UP006", # `use-pep585-annotation` Requires Python3.9+ "UP035", # `use-pep585-annotation` Requires Python3.9+ + "UP037", # `use-pep586-annotation` Requires Python3.9+ "UP038", # `use-pep585-annotation` Requires Python3.9+ ] diff --git a/statemachine/engines/async_.py b/statemachine/engines/async_.py index 75636378..df056597 100644 --- a/statemachine/engines/async_.py +++ b/statemachine/engines/async_.py @@ -5,7 +5,6 @@ from ..event_data import EventData from ..event_data import TriggerData from ..exceptions import TransitionNotAllowed -from ..i18n import _ from .base import BaseEngine if TYPE_CHECKING: @@ -46,7 +45,7 @@ async def processing_loop(self): first_result = self._sentinel try: # Execute the triggers in the queue in FIFO order until the queue is empty - while self._running and not self.empty(): + while self.running and not self.empty(): trigger_data = self.pop() current_time = time() if trigger_data.execution_time > current_time: diff --git a/statemachine/engines/base.py b/statemachine/engines/base.py index 39c1184c..cd983577 100644 --- a/statemachine/engines/base.py +++ b/statemachine/engines/base.py @@ -1,16 +1,19 @@ +from dataclasses import dataclass from itertools import chain from queue import PriorityQueue from queue import Queue from threading import Lock from typing import TYPE_CHECKING +from typing import Callable +from typing import List from weakref import proxy -from statemachine.orderedset import OrderedSet - from ..event import BoundEvent +from ..event import Event from ..event_data import EventData from ..event_data import TriggerData from ..exceptions import TransitionNotAllowed +from ..orderedset import OrderedSet from ..state import State from ..transition import Transition @@ -18,11 +21,21 @@ from ..statemachine import StateMachine +@dataclass(frozen=True, unsafe_hash=True) +class StateTransition: + transition: Transition + source: "State | None" = None + target: "State | None" = None + + class EventQueue: def __init__(self): self.queue: Queue = PriorityQueue() - def empty(self): + def __repr__(self): + return f"EventQueue({self.queue.queue!r}, size={self.queue.qsize()})" + + def is_empty(self): return self.queue.qsize() == 0 def put(self, trigger_data: TriggerData): @@ -54,18 +67,21 @@ def __init__(self, sm: "StateMachine"): self.external_queue = EventQueue() self.internal_queue = EventQueue() self._sentinel = object() - self._running = True + self.running = True self._processing = Lock() def empty(self): - return self.external_queue.empty() + return self.external_queue.is_empty() - def put(self, trigger_data: TriggerData): + def put(self, trigger_data: TriggerData, internal: bool = False): """Put the trigger on the queue without blocking the caller.""" - if not self._running and not self.sm.allow_event_without_transition: + if not self.running and not self.sm.allow_event_without_transition: raise TransitionNotAllowed(trigger_data.event, self.sm.current_state) - self.external_queue.put(trigger_data) + if internal: + self.internal_queue.put(trigger_data) + else: + self.external_queue.put(trigger_data) def pop(self): return self.external_queue.pop() @@ -87,3 +103,466 @@ def _initial_transition(self, trigger_data): transition = Transition(State(), self.sm._get_initial_state(), event="__initial__") transition._specs.clear() return transition + + def select_eventless_transitions(self, trigger_data: TriggerData): + """ + Select the eventless transitions that match the trigger data. + """ + return self._select_transitions(trigger_data, lambda t, _e: t.is_eventless) + + def _conditions_match(self, transition: Transition, trigger_data: TriggerData): + event_data = EventData(trigger_data=trigger_data, transition=transition) + args, kwargs = event_data.args, event_data.extended_kwargs + + self.sm._callbacks.call(transition.validators.key, *args, **kwargs) + return self.sm._callbacks.all(transition.cond.key, *args, **kwargs) + + def _filter_conflicting_transitions(self, transitions: OrderedSet[Transition]): + """ + Remove transições conflitantes, priorizando aquelas com estados de origem descendentes + ou que aparecem antes na ordem do documento. + + Args: + transitions (OrderedSet[Transition]): Conjunto de transições habilitadas. + + Returns: + OrderedSet[Transition]: Conjunto de transições sem conflitos. + """ + filtered_transitions = OrderedSet() + + # Ordena as transições na ordem dos estados que as selecionaram + for t1 in transitions: + t1_preempted = False + transitions_to_remove = OrderedSet() + + # Verifica conflitos com as transições já filtradas + for t2 in filtered_transitions: + # Calcula os conjuntos de saída (exit sets) + t1_exit_set = self._compute_exit_set(t1) + t2_exit_set = self._compute_exit_set(t2) + + # Verifica interseção dos conjuntos de saída + if t1_exit_set & t2_exit_set: # Há interseção + if t1.source.is_descendant(t2.source): + # t1 é preferido pois é descendente de t2 + transitions_to_remove.add(t2) + else: + # t2 é preferido pois foi selecionado antes na ordem do documento + t1_preempted = True + break + + # Se t1 não foi preemptado, adiciona a lista filtrada e remove os conflitantes + if not t1_preempted: + for t3 in transitions_to_remove: + filtered_transitions.discard(t3) + filtered_transitions.add(t1) + + return filtered_transitions + + def _compute_exit_set(self, transitions: List[Transition]) -> OrderedSet[StateTransition]: + """Compute the exit set for a transition.""" + + states_to_exit = OrderedSet() + + for transition in transitions: + if transition.target is None: + continue + domain = self.get_transition_domain(transition) + for state in self.sm.configuration: + if domain is None or state.is_descendant(domain): + info = StateTransition( + transition=transition, source=state, target=transition.target + ) + states_to_exit.add(info) + + return states_to_exit + + def get_transition_domain(self, transition: Transition) -> "State | None": + """ + Return the compound state such that + 1) all states that are exited or entered as a result of taking 'transition' are + descendants of it + 2) no descendant of it has this property. + """ + states = self.get_effective_target_states(transition) + if not states: + return None + elif ( + transition.internal + and transition.source.is_compound + and all(state.is_descendant(transition.source) for state in states) + ): + return transition.source + else: + return self.find_lcca([transition.source] + list(states)) + + @staticmethod + def find_lcca(states: List[State]) -> "State | None": + """ + Find the Least Common Compound Ancestor (LCCA) of the given list of states. + + Args: + state_list: A list of states. + + Returns: + The LCCA state, which is a proper ancestor of all states in the list, + or None if no such ancestor exists. + """ + # Get ancestors of the first state in the list, filtering for compound or SCXML elements + head, *tail = states + ancestors = [anc for anc in head.ancestors() if anc.is_compound] + + # Find the first ancestor that is also an ancestor of all other states in the list + for ancestor in ancestors: + if all(state.is_descendant(ancestor) for state in tail): + return ancestor + + return None + + def get_effective_target_states(self, transition: Transition) -> OrderedSet[State]: + # TODO: Handle history states + return OrderedSet([transition.target]) + + def select_transitions(self, trigger_data: TriggerData) -> OrderedSet[Transition]: + """ + Select the transitions that match the trigger data. + """ + return self._select_transitions(trigger_data, lambda t, e: t.match(e)) + + def _select_transitions( + self, trigger_data: TriggerData, predicate: Callable + ) -> OrderedSet[Transition]: + """Select the transitions that match the trigger data.""" + enabled_transitions = OrderedSet() + + # Get atomic states, TODO: sorted by document order + atomic_states = (state for state in self.sm.configuration if state.is_atomic) + + def first_transition_that_matches(state: State, event: Event) -> "Transition | None": + for s in chain([state], state.ancestors()): + for transition in s.transitions: + if predicate(transition, event) and self._conditions_match( + transition, trigger_data + ): + return transition + + for state in atomic_states: + transition = first_transition_that_matches(state, trigger_data.event) + if transition is not None: + enabled_transitions.add(transition) + + return self._filter_conflicting_transitions(enabled_transitions) + + def microstep(self, transitions: List[Transition], trigger_data: TriggerData): + """Process a single set of transitions in a 'lock step'. + This includes exiting states, executing transition content, and entering states. + """ + result = self._execute_transition_content( + transitions, trigger_data, lambda t: t.before.key + ) + + states_to_exit = self._exit_states(transitions, trigger_data) + result += self._execute_transition_content(transitions, trigger_data, lambda t: t.on.key) + self._enter_states(transitions, trigger_data, states_to_exit) + self._execute_transition_content( + transitions, + trigger_data, + lambda t: t.after.key, + set_target_as_state=True, + ) + + if len(result) == 0: + result = None + elif len(result) == 1: + result = result[0] + + return result + + def _exit_states(self, enabled_transitions: List[Transition], trigger_data: TriggerData): + """Compute and process the states to exit for the given transitions.""" + states_to_exit = self._compute_exit_set(enabled_transitions) + + # # TODO: Remove states from states_to_invoke + # for state in states_to_exit: + # self.states_to_invoke.discard(state) + + # TODO: Sort states to exit in exit order + # states_to_exit = sorted(states_to_exit, key=self.exit_order) + + for info in states_to_exit: + event_data = EventData(trigger_data=trigger_data, transition=info.transition) + args, kwargs = event_data.args, event_data.extended_kwargs + + # # TODO: Update history + # for history in state.history: + # if history.type == "deep": + # history_value = [s for s in self.sm.configuration if self.is_descendant(s, state)] # noqa: E501 + # else: # shallow history + # history_value = [s for s in self.sm.configuration if s.parent == state] + # self.history_values[history.id] = history_value + + # Execute `onexit` handlers + if info.source is not None and not info.transition.internal: + self.sm._callbacks.call(info.source.exit.key, *args, **kwargs) + + # TODO: Cancel invocations + # for invocation in state.invoke: + # self.cancel_invoke(invocation) + + # Remove state from configuration + # self.sm.configuration -= {info.source} # .discard(info.source) + + return OrderedSet([info.source for info in states_to_exit]) + + def _execute_transition_content( + self, + enabled_transitions: List[Transition], + trigger_data: TriggerData, + get_key: Callable[[Transition], str], + set_target_as_state: bool = False, + ): + result = [] + for transition in enabled_transitions: + event_data = EventData(trigger_data=trigger_data, transition=transition) + if set_target_as_state: + event_data.state = transition.target + args, kwargs = event_data.args, event_data.extended_kwargs + + result += self.sm._callbacks.call(get_key(transition), *args, **kwargs) + + return result + + def _enter_states( + self, + enabled_transitions: List[Transition], + trigger_data: TriggerData, + states_to_exit: OrderedSet[State], + ): + """Enter the states as determined by the given transitions.""" + states_to_enter = OrderedSet[StateTransition]() + states_for_default_entry = OrderedSet[StateTransition]() + default_history_content = {} + + # Compute the set of states to enter + self.compute_entry_set( + enabled_transitions, states_to_enter, states_for_default_entry, default_history_content + ) + + # We update the configuration atomically + states_targets_to_enter = OrderedSet(info.target for info in states_to_enter) + configuration = self.sm.configuration + self.sm.configuration = (configuration - states_to_exit) | states_targets_to_enter + + # Sort states to enter in entry order + # for state in sorted(states_to_enter, key=self.entry_order): # TODO: ordegin of states_to_enter # noqa: E501 + for info in states_to_enter: + target = info.target + transition = info.transition + event_data = EventData(trigger_data=trigger_data, transition=transition) + event_data.state = target + args, kwargs = event_data.args, event_data.extended_kwargs + + # Add state to the configuration + # self.sm.configuration |= {target} + + # TODO: Add state to states_to_invoke + # self.states_to_invoke.add(state) + + # Initialize data model if using late binding + # if self.binding == "late" and state.is_first_entry: + # self.initialize_data_model(state) + # state.is_first_entry = False + + # Execute `onentry` handlers + if not transition.internal: + self.sm._callbacks.call(target.enter.key, *args, **kwargs) + + # Handle default initial states + # TODO: Handle default initial states + # if state in states_for_default_entry: + # self.execute_content(state.initial.transition) + + # Handle default history states + # if state.id in default_history_content: + # self.execute_content(default_history_content[state.id]) + + # Handle final states + if target.final: + if target.parent is None: + self.running = False + else: + parent = target.parent + grandparent = parent.parent + + self.internal_queue.put(BoundEvent(f"done.state.{parent.id}", _sm=self.sm)) + if grandparent.parallel: + if all(child.final for child in grandparent.states): + self.internal_queue.put( + BoundEvent(f"done.state.{parent.id}", _sm=self.sm) + ) + + def compute_entry_set( + self, transitions, states_to_enter, states_for_default_entry, default_history_content + ): + """ + Compute the set of states to be entered based on the given transitions. + + Args: + transitions: A list of transitions. + states_to_enter: A set to store the states that need to be entered. + states_for_default_entry: A set to store compound states requiring default entry + processing. + default_history_content: A dictionary to hold temporary content for history states. + """ + for transition in transitions: + # Process each target state of the transition + for target_state in [transition.target]: + info = StateTransition( + transition=transition, target=target_state, source=transition.source + ) + self.add_descendant_states_to_enter( + info, states_to_enter, states_for_default_entry, default_history_content + ) + + # Determine the ancestor state (transition domain) + ancestor = self.get_transition_domain(transition) + + # Add ancestor states to enter for each effective target state + for effective_target in self.get_effective_target_states(transition): + info = StateTransition( + transition=transition, target=effective_target, source=transition.source + ) + self.add_ancestor_states_to_enter( + info, + ancestor, + states_to_enter, + states_for_default_entry, + default_history_content, + ) + + def add_descendant_states_to_enter( + self, + info: StateTransition, + states_to_enter, + states_for_default_entry, + default_history_content, + ): + """ + Add the given state and its descendants to the entry set. + + Args: + state: The state to add to the entry set. + states_to_enter: A set to store the states that need to be entered. + states_for_default_entry: A set to track compound states requiring default entry + processing. + default_history_content: A dictionary to hold temporary content for history states. + """ + # if self.is_history_state(state): + # # Handle history state + # if state.id in self.history_values: + # for history_state in self.history_values[state.id]: + # self.add_descendant_states_to_enter(history_state, states_to_enter, states_for_default_entry, default_history_content) # noqa: E501 + # for history_state in self.history_values[state.id]: + # self.add_ancestor_states_to_enter(history_state, state.parent, states_to_enter, states_for_default_entry, default_history_content) # noqa: E501 + # else: + # # Handle default history content + # default_history_content[state.parent.id] = state.transition.content + # for target_state in state.transition.target: + # self.add_descendant_states_to_enter(target_state, states_to_enter, states_for_default_entry, default_history_content) # noqa: E501 + # for target_state in state.transition.target: + # self.add_ancestor_states_to_enter(target_state, state.parent, states_to_enter, states_for_default_entry, default_history_content) # noqa: E501 + # return + + # Add the state to the entry set + states_to_enter.add(info) + state = info.target + + if state.is_compound: + # Handle compound states + states_for_default_entry.add(info) + initial_state = next(s for s in state.states if s.initial) + for transition in initial_state.transitions: + info_initial = StateTransition( + transition=transition, + target=transition.target, + source=transition.source, + ) + self.add_descendant_states_to_enter( + info_initial, + states_to_enter, + states_for_default_entry, + default_history_content, + ) + for transition in initial_state.transitions: + info_initial = StateTransition( + transition=transition, + target=transition.target, + source=transition.source, + ) + self.add_ancestor_states_to_enter( + info_initial, + state, + states_to_enter, + states_for_default_entry, + default_history_content, + ) + elif state.parallel: + # Handle parallel states + for child_state in state.states: + if not any(s.target.is_descendant(child_state) for s in states_to_enter): + info_to_add = StateTransition( + transition=info.transition, + target=child_state, + source=info.transition.source, + ) + self.add_descendant_states_to_enter( + info_to_add, + states_to_enter, + states_for_default_entry, + default_history_content, + ) + + def add_ancestor_states_to_enter( + self, + info: StateTransition, + ancestor, + states_to_enter, + states_for_default_entry, + default_history_content, + ): + """ + Add ancestors of the given state to the entry set. + + Args: + state: The state whose ancestors are to be added. + ancestor: The upper bound ancestor (exclusive) to stop at. + states_to_enter: A set to store the states that need to be entered. + states_for_default_entry: A set to track compound states requiring default entry + processing. + default_history_content: A dictionary to hold temporary content for history states. + """ + state = info.target + for anc in state.ancestors(parent=ancestor): + # Add the ancestor to the entry set + info_to_add = StateTransition( + transition=info.transition, + target=anc, + source=info.transition.source, + ) + states_to_enter.add(info_to_add) + + if anc.parallel: + # Handle parallel states + for child in anc.states: + if not any(s.target.is_descendant(child) for s in states_to_enter): + info_to_add = StateTransition( + transition=info.transition, + target=child, + source=info.transition.source, + ) + self.add_descendant_states_to_enter( + child, + states_to_enter, + states_for_default_entry, + default_history_content, + ) diff --git a/statemachine/engines/sync.py b/statemachine/engines/sync.py index e1d9dfcf..bd47435a 100644 --- a/statemachine/engines/sync.py +++ b/statemachine/engines/sync.py @@ -2,9 +2,9 @@ from time import time from typing import TYPE_CHECKING +from statemachine.event import BoundEvent from statemachine.orderedset import OrderedSet -from ..event_data import EventData from ..event_data import TriggerData from ..exceptions import TransitionNotAllowed from .base import BaseEngine @@ -15,7 +15,9 @@ class SyncEngine(BaseEngine): def start(self): - super().start() + if self.sm.current_state_value is not None: + return + self.activate_initial_state() def activate_initial_state(self): @@ -28,9 +30,17 @@ def activate_initial_state(self): Given how async works on python, there's no built-in way to activate the initial state that may depend on async code from the StateMachine.__init__ method. """ + if self.sm.current_state_value is None: + trigger_data = BoundEvent("__initial__", _sm=self.sm).build_trigger(machine=self.sm) + transition = self._initial_transition(trigger_data) + self._processing.acquire(blocking=False) + try: + self._enter_states([transition], trigger_data, {}) + finally: + self._processing.release() return self.processing_loop() - def processing_loop(self): + def processing_loop(self): # noqa: C901 """Process event triggers. The event is put on a queue, and only the first event will have the result collected. @@ -50,86 +60,84 @@ def processing_loop(self): # be also `None`, and on this case the `first_result` may be overridden by another result. first_result = self._sentinel try: - # Execute the triggers in the queue in FIFO order until the queue is empty - while self._running and not self.empty(): - trigger_data = self.pop() - current_time = time() - if trigger_data.execution_time > current_time: - self.put(trigger_data) - sleep(0.001) - continue - try: - result = self._trigger(trigger_data) - if first_result is self._sentinel: - first_result = result - except Exception: - # Whe clear the queue as we don't have an expected behavior - # and cannot keep processing - self.clear() - raise + took_events = True + while took_events: + took_events = False + # Execute the triggers in the queue in FIFO order until the queue is empty + # while self._running and not self.external_queue.is_empty(): + macrostep_done = False + enabled_transitions: "OrderedSet[Transition] | None" = None + + # handles eventless transitions and internal events + while not macrostep_done: + internal_event = TriggerData( + self.sm, event=None + ) # this one is a "null object" + enabled_transitions = self.select_eventless_transitions(internal_event) + if not enabled_transitions: + if self.internal_queue.is_empty(): + macrostep_done = True + else: + internal_event = self.internal_queue.pop() + + enabled_transitions = self.select_transitions(internal_event) + if enabled_transitions: + took_events = True + self.microstep(list(enabled_transitions), internal_event) + + # TODO: Invoke platform-specific logic + # for state in sorted(self.states_to_invoke, key=self.entry_order): + # for inv in sorted(state.invoke, key=self.document_order): + # self.invoke(inv) + # self.states_to_invoke.clear() + + # Process remaining internal events before external events + while not self.internal_queue.is_empty(): + internal_event = self.internal_queue.pop() + enabled_transitions = self.select_transitions(internal_event) + if enabled_transitions: + self.microstep(list(enabled_transitions)) + + # Process external events + while not self.external_queue.is_empty(): + took_events = True + external_event = self.external_queue.pop() + current_time = time() + if external_event.execution_time > current_time: + self.put(external_event) + sleep(0.001) + continue + + # # TODO: Handle cancel event + # if self.is_cancel_event(external_event): + # self.running = False + # return + + # TODO: Invoke states + # for state in self.configuration: + # for inv in state.invoke: + # if inv.invokeid == external_event.invokeid: + # self.apply_finalize(inv, external_event) + # if inv.autoforward: + # self.send(inv.id, external_event) + + enabled_transitions = self.select_transitions(external_event) + if enabled_transitions: + try: + result = self.microstep(list(enabled_transitions), external_event) + if first_result is self._sentinel: + first_result = result + + except Exception: + # Whe clear the queue as we don't have an expected behavior + # and cannot keep processing + self.clear() + raise + + else: + if not self.sm.allow_event_without_transition: + raise TransitionNotAllowed(external_event.event, self.sm.current_state) + finally: self._processing.release() return first_result if first_result is not self._sentinel else None - - def _trigger(self, trigger_data: TriggerData): # noqa: C901 - executed = False - if trigger_data.event == "__initial__": - transition = self._initial_transition(trigger_data) - self._activate(trigger_data, transition) - if self.sm.current_state.transitions.has_eventless_transition: - self.put(TriggerData(self.sm, event=None)) - return self._sentinel - - state = self.sm.current_state - for transition in state.transitions: - if not transition.match(trigger_data.event): - continue - - executed, result = self._activate(trigger_data, transition) - if not executed: - continue - - if self.sm.current_state.transitions.has_eventless_transition: - self.put(TriggerData(self.sm, event=None)) - break - else: - if not self.sm.allow_event_without_transition: - raise TransitionNotAllowed(trigger_data.event, state) - - return result if executed else None - - def _activate(self, trigger_data: TriggerData, transition: "Transition"): # noqa: C901 - event_data = EventData(trigger_data=trigger_data, transition=transition) - args, kwargs = event_data.args, event_data.extended_kwargs - - self.sm._callbacks.call(transition.validators.key, *args, **kwargs) - if not self.sm._callbacks.all(transition.cond.key, *args, **kwargs): - return False, None - - source = transition.source - target = transition.target - - result = self.sm._callbacks.call(transition.before.key, *args, **kwargs) - if source is not None and not transition.internal: - self.sm._callbacks.call(source.exit.key, *args, **kwargs) - - result += self.sm._callbacks.call(transition.on.key, *args, **kwargs) - - self.sm.configuration = OrderedSet([target]) - event_data.state = target - kwargs["state"] = target - - if not transition.internal: - self.sm._callbacks.call(target.enter.key, *args, **kwargs) - self.sm._callbacks.call(transition.after.key, *args, **kwargs) - - if target.final: - self.clear() - self._running = False - - if len(result) == 0: - result = None - elif len(result) == 1: - result = result[0] - - return True, result diff --git a/statemachine/event.py b/statemachine/event.py index 3eaa42c3..d67df060 100644 --- a/statemachine/event.py +++ b/statemachine/event.py @@ -45,6 +45,9 @@ class Event(AddCallbacksMixin, str): delay: float = 0 """The delay in milliseconds before the event is triggered. Default is 0.""" + internal: bool = False + """Indicates if the events should be placed on the internal event queue.""" + _sm: "StateMachine | None" = None """The state machine instance.""" @@ -57,6 +60,7 @@ def __new__( id: "str | None" = None, name: "str | None" = None, delay: float = 0, + internal: bool = False, _sm: "StateMachine | None" = None, ): if isinstance(transitions, str): @@ -69,6 +73,7 @@ def __new__( instance = super().__new__(cls, id) instance.id = id instance.delay = delay + instance.internal = internal if name: instance.name = name elif _has_real_id: @@ -82,7 +87,9 @@ def __new__( return instance def __repr__(self): - return f"{type(self).__name__}({self.id!r})" + return ( + f"{type(self).__name__}({self.id!r}, delay={self.delay!r}, internal={self.internal!r})" + ) def is_same_event(self, *_args, event: "str | None" = None, **_kwargs) -> bool: return self == event @@ -116,6 +123,13 @@ def put(self, *args, machine: "StateMachine", send_id: "str | None" = None, **kw # can be called as a method. But it is not meant to be called without # an SM instance. Such SM instance is provided by `__get__` method when # used as a property descriptor. + trigger_data = self.build_trigger(*args, machine=machine, send_id=send_id, **kwargs) + machine._put_nonblocking(trigger_data, internal=self.internal) + return trigger_data + + def build_trigger( + self, *args, machine: "StateMachine", send_id: "str | None" = None, **kwargs + ): if machine is None: raise RuntimeError(_("Event {} cannot be called without a SM instance").format(self)) @@ -127,7 +141,7 @@ def put(self, *args, machine: "StateMachine", send_id: "str | None" = None, **kw args=args, kwargs=kwargs, ) - machine._put_nonblocking(trigger_data) + return trigger_data def __call__(self, *args, **kwargs): diff --git a/statemachine/event_data.py b/statemachine/event_data.py index d60f43fb..4c66c956 100644 --- a/statemachine/event_data.py +++ b/statemachine/event_data.py @@ -25,6 +25,8 @@ class TriggerData: Allow revoking a delayed :ref:`TriggerData` instance. """ + _target: "str | None" = field(init=False, compare=False, default=None) + execution_time: float = field(default=0.0) """The time at which the :ref:`Event` should run.""" @@ -62,10 +64,6 @@ class EventData: target: "State" = field(init=False) """The destination :ref:`State` of the :ref:`transition`.""" - result: "Any | None" = None - - executed: bool = False - def __post_init__(self): self.state = self.transition.source self.source = self.transition.source diff --git a/statemachine/factory.py b/statemachine/factory.py index 8b80e516..8e72107c 100644 --- a/statemachine/factory.py +++ b/statemachine/factory.py @@ -246,9 +246,9 @@ def _add_unbounded_callback(cls, attr_name, func): def add_state(cls, id, state: State): state._set_id(id) + cls.states_map[state.value] = state if not state.parent: cls.states.append(state) - cls.states_map[state.value] = state if not hasattr(cls, id): setattr(cls, id, state) diff --git a/statemachine/io/scxml/actions.py b/statemachine/io/scxml/actions.py index bf17d0fe..9babfd45 100644 --- a/statemachine/io/scxml/actions.py +++ b/statemachine/io/scxml/actions.py @@ -5,11 +5,9 @@ from itertools import chain from typing import Any from typing import Callable -from typing import List from uuid import uuid4 from statemachine.exceptions import InvalidDefinition -from statemachine.model import Model from ...event import Event from ...statemachine import StateMachine @@ -163,10 +161,15 @@ def __init__(self, cond: str, processor=None): self.processor = processor def __call__(self, *args, **kwargs): + machine = kwargs["machine"] if self.processor: kwargs["_ioprocessors"] = self.processor.wrap(**kwargs) - return _eval(self.action, **kwargs) + try: + return _eval(self.action, **kwargs) + except Exception as e: + machine.send("error.execution", error=e, internal=True) + return False @staticmethod def _normalize(cond: "str | None") -> "str | None": @@ -215,15 +218,6 @@ def create_action_callable(action: Action) -> Callable: raise ValueError(f"Unknown action type: {type(action)}") -def create_raise_action_callable(action: RaiseAction) -> Callable: - def raise_action(*args, **kwargs): - machine: StateMachine = kwargs["machine"] - machine.send(action.event) - - raise_action.action = action # type: ignore[attr-defined] - return raise_action - - class Assign(CallableAction): def __init__(self, action: AssignAction): super().__init__() @@ -238,9 +232,9 @@ def __call__(self, *args, **kwargs): for p in path: obj = getattr(obj, p) - if not attr.isidentifier(): + if not attr.isidentifier() or not hasattr(obj, attr): raise ValueError( - f" 'location' must be a valid Python attribute name, " + f" 'location' must be a valid Python attribute name and must be declared, " f"got: {self.action.location}" ) setattr(obj, attr, value) @@ -309,8 +303,21 @@ def foreach_action(*args, **kwargs): return foreach_action +def create_raise_action_callable(action: RaiseAction) -> Callable: + def raise_action(*args, **kwargs): + machine: StateMachine = kwargs["machine"] + + Event(id=action.event, name=action.event, internal=True).put( + machine=machine, + ) + + raise_action.action = action # type: ignore[attr-defined] + return raise_action + + def create_send_action_callable(action: SendAction) -> Callable: content: Any = () + _valid_targets = (None, "#_internal", "internal", "#_parent", "parent") if action.content: try: content = (eval(action.content, {}, {}),) @@ -320,11 +327,16 @@ def create_send_action_callable(action: SendAction) -> Callable: def send_action(*args, **kwargs): machine: StateMachine = kwargs["machine"] event = action.event or _eval(action.eventexpr, **kwargs) - _target = _eval(action.target, **kwargs) if action.target else None + target = action.target if action.target else None + if action.type and action.type != "http://www.w3.org/TR/scxml/#SCXMLEventProcessor": raise ValueError( "Only 'http://www.w3.org/TR/scxml/#SCXMLEventProcessor' event type is supported" ) + if target not in _valid_targets: + raise ValueError(f"Invalid target: {target}. Must be one of {_valid_targets}") + + internal = target in ("#_internal", "internal") if action.id: send_id = action.id @@ -343,7 +355,7 @@ def send_action(*args, **kwargs): for param in chain(names, action.params): params_values[param.name] = _eval(param.expr, **kwargs) - Event(id=event, name=event, delay=delay).put( + Event(id=event, name=event, delay=delay, internal=internal).put( *content, machine=machine, send_id=send_id, @@ -388,10 +400,16 @@ def script_action(*args, **kwargs): def _create_dataitem_callable(action: DataItem) -> Callable: - def data_initializer(machine: StateMachine, **kwargs): - # Evaluate the expression if provided, or set to None + def data_initializer(**kwargs): + machine: StateMachine = kwargs["machine"] + if action.expr: - value = _eval(action.expr, **kwargs) + try: + value = _eval(action.expr, **kwargs) + except Exception: + setattr(machine.model, action.id, None) + raise + elif action.content: try: value = _eval(action.content, **kwargs) @@ -412,29 +430,18 @@ def create_datamodel_action_callable(action: DataModel) -> "Callable | None": if not data_elements: return None - def __init__( - self, - model: Any = None, - state_field: str = "state", - start_value: Any = None, - allow_event_without_transition: bool = True, - listeners: "List[object] | None" = None, - ): - model = model if model else Model() - self.model = model + def datamodel(*args, **kwargs): + machine: StateMachine = kwargs["machine"] for act in data_elements: - act(machine=self) - - StateMachine.__init__( - self, - model, - state_field=state_field, - start_value=start_value, - allow_event_without_transition=allow_event_without_transition, - listeners=listeners, - ) + try: + act(machine=machine) + except Exception as e: + logger.debug("Error executing actions", exc_info=True) + if isinstance(e, InvalidDefinition): + raise + machine.send("error.execution", error=e, internal=True) - return __init__ + return datamodel class ExecuteBlock(CallableAction): @@ -456,4 +463,4 @@ def __call__(self, *args, **kwargs): logger.debug("Error executing actions", exc_info=True) if isinstance(e, InvalidDefinition): raise - machine.send("error.execution", error=e) + machine.send("error.execution", error=e, internal=True) diff --git a/statemachine/io/scxml/processor.py b/statemachine/io/scxml/processor.py index fa8f9c58..2f5b8c6a 100644 --- a/statemachine/io/scxml/processor.py +++ b/statemachine/io/scxml/processor.py @@ -35,15 +35,16 @@ def parse_scxml(self, sm_name: str, scxml_content: str): def process_definition(self, definition, location: str): states_dict = self._process_states(definition.states) - extra_data = {} - # Process datamodel (initial variables) if definition.datamodel: - __init__ = create_datamodel_action_callable(definition.datamodel) - if __init__: - extra_data["__init__"] = __init__ + datamodel = create_datamodel_action_callable(definition.datamodel) + if datamodel: + initial_state = next(s for s in iter(states_dict.values()) if s["initial"]) + if "enter" not in initial_state: + initial_state["enter"] = [] + initial_state["enter"].insert(0, datamodel) - self._add(location, {"states": states_dict, **extra_data}) + self._add(location, {"states": states_dict}) def _process_states(self, states: Dict[str, State]) -> Dict[str, StateDefinition]: states_dict: Dict[str, StateDefinition] = {} @@ -113,6 +114,7 @@ def _add(self, location: str, definition: Dict[str, Any]): ) from e def start(self, **kwargs): + kwargs["allow_event_without_transition"] = True self.root_cls = next(iter(self.scs.values())) self.root = self.root_cls(**kwargs) self.sessions[self.root.name] = self.root diff --git a/statemachine/state.py b/statemachine/state.py index 57cdef8e..b6ef358f 100644 --- a/statemachine/state.py +++ b/statemachine/state.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING from typing import Any from typing import Dict +from typing import Generator from typing import List from weakref import ref @@ -103,7 +104,7 @@ class State: Given a few states... - >>> draft = State("Draft", initial=True) + >>> draft = State(name="Draft", initial=True) >>> producing = State("Producing") @@ -112,7 +113,7 @@ class State: Transitions are declared using the :func:`State.to` or :func:`State.from_` (reversed) methods. >>> draft.to(producing) - TransitionList([Transition(State('Draft', ... + TransitionList([Transition('Draft', 'Producing', event='', internal=False)]) The result is a :ref:`TransitionList`. Don't worry about this internal class. @@ -137,7 +138,7 @@ class State: expressed using an alternative syntax: >>> draft.to.itself() - TransitionList([Transition(State('Draft', ... + TransitionList([Transition('Draft', 'Draft', event='', internal=False)]) You can even pass a list of target states to declare at once all transitions starting from the same state. @@ -183,7 +184,7 @@ def __init__( initial: bool = False, final: bool = False, parallel: bool = False, - states: Any = None, + states: "List[State] | None" = None, enter: Any = None, exit: Any = None, _callbacks: Any = None, @@ -291,6 +292,21 @@ def final(self): def parallel(self): return self._parallel + @property + def is_compound(self): + return bool(self.states) + + def ancestors(self, parent: "State | None" = None) -> Generator: + selected = self + while selected: + if parent and selected == parent: + break + yield selected + selected = selected.parent + + def is_descendant(self, state: "State | None") -> bool: + return state in self.ancestors() + class InstanceState(State): """ """ diff --git a/statemachine/statemachine.py b/statemachine/statemachine.py index d1767495..98cc52a8 100644 --- a/statemachine/statemachine.py +++ b/statemachine/statemachine.py @@ -270,7 +270,11 @@ def current_state_value(self): @current_state_value.setter def current_state_value(self, value): - if not isinstance(value, MutableSet) and value not in self.states_map: + if ( + value is not None + and not isinstance(value, MutableSet) + and value not in self.states_map + ): raise InvalidStateValue(value) setattr(self.model, self.state_field, value) @@ -313,11 +317,19 @@ def allowed_events(self) -> "List[Event]": """List of the current allowed events.""" return [getattr(self, event) for event in self.current_state.transitions.unique_events] - def _put_nonblocking(self, trigger_data: TriggerData): + def _put_nonblocking(self, trigger_data: TriggerData, internal: bool = False): """Put the trigger on the queue without blocking the caller.""" - self._engine.put(trigger_data) + self._engine.put(trigger_data, internal=internal) - def send(self, event: str, *args, delay: float = 0, event_id: "str | None" = None, **kwargs): + def send( + self, + event: str, + *args, + delay: float = 0, + event_id: "str | None" = None, + internal: bool = False, + **kwargs, + ): """Send an :ref:`Event` to the state machine. :param event: The trigger for the state machine, specified as an event id string. @@ -334,16 +346,30 @@ def send(self, event: str, *args, delay: float = 0, event_id: "str | None" = Non delay = ( delay if delay else know_event and know_event.delay or 0 ) # first the param, then the event, or 0 - event_instance = BoundEvent(id=event, name=event_name, delay=delay, _sm=self) + event_instance = BoundEvent( + id=event, name=event_name, delay=delay, internal=internal, _sm=self + ) result = event_instance(*args, event_id=event_id, **kwargs) if not isawaitable(result): return result return run_async_from_sync(result) + def raise_(self, event: str, *args, delay: float = 0, event_id: "str | None" = None, **kwargs): + """Send an :ref:`Event` to the state machine in the internal event queue. + + Events on the internal queue are processed immediately on the current step of the + interpreter. + + .. seealso:: + + See: :ref:`triggering events`. + """ + return self.send(event, *args, delay=delay, event_id=event_id, internal=True, **kwargs) + def cancel_event(self, send_id: str): """Cancel all the delayed events with the given ``send_id``.""" self._engine.cancel_event(send_id) @property def is_terminated(self): - return not self._engine._running + return not self._engine.running diff --git a/statemachine/transition.py b/statemachine/transition.py index 483c714d..b250af22 100644 --- a/statemachine/transition.py +++ b/statemachine/transition.py @@ -76,8 +76,8 @@ def __init__( def __repr__(self): return ( - f"{type(self).__name__}({self.source!r}, {self.target!r}, event={self.event!r}, " - f"internal={self.internal!r})" + f"{type(self).__name__}({self.source.name!r}, {self.target.name!r}, " + f"event={self.event!r}, internal={self.internal!r})" ) def __str__(self): diff --git a/tests/conftest.py b/tests/conftest.py index 8a304e12..9430b3af 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,10 @@ collect_ignore_glob.append("*_positional_only.py") +# TODO: Return django to collect +collect_ignore_glob.append("django") + + @pytest.fixture() def current_time(): return datetime.now() diff --git a/tests/examples/all_actions_machine.py b/tests/examples/all_actions_machine.py index cf89cfa5..ef2e62e0 100644 --- a/tests/examples/all_actions_machine.py +++ b/tests/examples/all_actions_machine.py @@ -152,7 +152,7 @@ def on_exit_final(self): # Only before/on actions have their result collected. result = machine.go() -assert result == [ +expected = [ "before_transition", "before_go_inline_1", "before_go_inline_2", @@ -164,6 +164,7 @@ def on_exit_final(self): "go_on_decor", "on_go", ] +assert result == expected # %% # Checking the method resolution order diff --git a/tests/scxml/test_scxml_cases.py b/tests/scxml/test_scxml_cases.py index f0d85c14..4494d95a 100644 --- a/tests/scxml/test_scxml_cases.py +++ b/tests/scxml/test_scxml_cases.py @@ -24,9 +24,12 @@ class DebugListener: def on_transition(self, event: Event, source: State, target: State, event_data): self.events.append( - f"{source and source.id} -- " - f"{event and event.id}{event_data.trigger_data.kwargs} --> " - f"{target.id}" + ( + f"{source and source.id}", + f"{event and event.id}", + f"{event_data.trigger_data.kwargs}", + f"{target.id}", + ) ) diff --git a/tests/scxml/w3c/mandatory/test189-fail-fail.scxml b/tests/scxml/w3c/mandatory/test189.scxml similarity index 100% rename from tests/scxml/w3c/mandatory/test189-fail-fail.scxml rename to tests/scxml/w3c/mandatory/test189.scxml diff --git a/tests/scxml/w3c/mandatory/test190-fail-fail.scxml b/tests/scxml/w3c/mandatory/test190-fail.scxml similarity index 100% rename from tests/scxml/w3c/mandatory/test190-fail-fail.scxml rename to tests/scxml/w3c/mandatory/test190-fail.scxml diff --git a/tests/scxml/w3c/mandatory/test191-fail-fail.scxml b/tests/scxml/w3c/mandatory/test191-fail.scxml similarity index 97% rename from tests/scxml/w3c/mandatory/test191-fail-fail.scxml rename to tests/scxml/w3c/mandatory/test191-fail.scxml index 50e226e3..a54c9beb 100644 --- a/tests/scxml/w3c/mandatory/test191-fail-fail.scxml +++ b/tests/scxml/w3c/mandatory/test191-fail.scxml @@ -8,7 +8,7 @@ hang. --> initial="s0" version="1.0" datamodel="ecmascript"> - + diff --git a/tests/scxml/w3c/mandatory/test192-fail-fail.scxml b/tests/scxml/w3c/mandatory/test192-fail.scxml similarity index 59% rename from tests/scxml/w3c/mandatory/test192-fail-fail.scxml rename to tests/scxml/w3c/mandatory/test192-fail.scxml index e4458743..da7251fd 100644 --- a/tests/scxml/w3c/mandatory/test192-fail-fail.scxml +++ b/tests/scxml/w3c/mandatory/test192-fail.scxml @@ -1,52 +1,57 @@ - + - + - - - + + - + - + - + - - + + - + - + - + - + diff --git a/tests/scxml/w3c/mandatory/test207-fail-fail.scxml b/tests/scxml/w3c/mandatory/test207-fail.scxml similarity index 100% rename from tests/scxml/w3c/mandatory/test207-fail-fail.scxml rename to tests/scxml/w3c/mandatory/test207-fail.scxml diff --git a/tests/scxml/w3c/mandatory/test215-fail-fail.scxml b/tests/scxml/w3c/mandatory/test215-fail-fail.scxml deleted file mode 100644 index 0c3634aa..00000000 --- a/tests/scxml/w3c/mandatory/test215-fail-fail.scxml +++ /dev/null @@ -1,29 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tests/scxml/w3c/mandatory/test215-fail.scxml b/tests/scxml/w3c/mandatory/test215-fail.scxml new file mode 100644 index 00000000..0d34c3f4 --- /dev/null +++ b/tests/scxml/w3c/mandatory/test215-fail.scxml @@ -0,0 +1,41 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/scxml/w3c/mandatory/test216-fail.scxml b/tests/scxml/w3c/mandatory/test216-fail.scxml index d94d5890..d91ae4c0 100644 --- a/tests/scxml/w3c/mandatory/test216-fail.scxml +++ b/tests/scxml/w3c/mandatory/test216-fail.scxml @@ -11,7 +11,7 @@ the runtime value is used, the invocation will succeed --> - + diff --git a/tests/scxml/w3c/mandatory/test220-fail.scxml b/tests/scxml/w3c/mandatory/test220-fail.scxml index 2f3370c8..0b0cd3e7 100644 --- a/tests/scxml/w3c/mandatory/test220-fail.scxml +++ b/tests/scxml/w3c/mandatory/test220-fail.scxml @@ -4,7 +4,7 @@ - + diff --git a/tests/scxml/w3c/mandatory/test277-fail.scxml b/tests/scxml/w3c/mandatory/test277.scxml similarity index 91% rename from tests/scxml/w3c/mandatory/test277-fail.scxml rename to tests/scxml/w3c/mandatory/test277.scxml index 57bf068b..ed495641 100644 --- a/tests/scxml/w3c/mandatory/test277-fail.scxml +++ b/tests/scxml/w3c/mandatory/test277.scxml @@ -1,4 +1,5 @@ - diff --git a/tests/scxml/w3c/mandatory/test279-fail.scxml b/tests/scxml/w3c/mandatory/test279.scxml similarity index 90% rename from tests/scxml/w3c/mandatory/test279-fail.scxml rename to tests/scxml/w3c/mandatory/test279.scxml index 36fe69d3..79e9cf5f 100644 --- a/tests/scxml/w3c/mandatory/test279-fail.scxml +++ b/tests/scxml/w3c/mandatory/test279.scxml @@ -8,6 +8,7 @@ early binding variables are assigned values at init time --> + diff --git a/tests/scxml/w3c/mandatory/test280-fail.scxml b/tests/scxml/w3c/mandatory/test280.scxml similarity index 88% rename from tests/scxml/w3c/mandatory/test280-fail.scxml rename to tests/scxml/w3c/mandatory/test280.scxml index 0f516bf5..814b67ec 100644 --- a/tests/scxml/w3c/mandatory/test280-fail.scxml +++ b/tests/scxml/w3c/mandatory/test280.scxml @@ -1,5 +1,6 @@ - @@ -20,7 +21,6 @@ possible access it there and assign its value to var1 --> - diff --git a/tests/scxml/w3c/mandatory/test286-fail.scxml b/tests/scxml/w3c/mandatory/test286.scxml similarity index 86% rename from tests/scxml/w3c/mandatory/test286-fail.scxml rename to tests/scxml/w3c/mandatory/test286.scxml index 5be6f46f..9ac71547 100644 --- a/tests/scxml/w3c/mandatory/test286-fail.scxml +++ b/tests/scxml/w3c/mandatory/test286.scxml @@ -1,4 +1,5 @@ - diff --git a/tests/scxml/w3c/mandatory/test309-fail.scxml b/tests/scxml/w3c/mandatory/test309.scxml similarity index 81% rename from tests/scxml/w3c/mandatory/test309-fail.scxml rename to tests/scxml/w3c/mandatory/test309.scxml index 54396083..645268f9 100644 --- a/tests/scxml/w3c/mandatory/test309-fail.scxml +++ b/tests/scxml/w3c/mandatory/test309.scxml @@ -1,4 +1,5 @@ - diff --git a/tests/scxml/w3c/mandatory/test311-fail.scxml b/tests/scxml/w3c/mandatory/test311.scxml similarity index 83% rename from tests/scxml/w3c/mandatory/test311-fail.scxml rename to tests/scxml/w3c/mandatory/test311.scxml index c50e57b7..700ec79d 100644 --- a/tests/scxml/w3c/mandatory/test311-fail.scxml +++ b/tests/scxml/w3c/mandatory/test311.scxml @@ -1,4 +1,5 @@ - + + diff --git a/tests/scxml/w3c/mandatory/test322-fail.scxml b/tests/scxml/w3c/mandatory/test322.scxml similarity index 90% rename from tests/scxml/w3c/mandatory/test322-fail.scxml rename to tests/scxml/w3c/mandatory/test322.scxml index 31c8e2cc..21c7f28b 100644 --- a/tests/scxml/w3c/mandatory/test322-fail.scxml +++ b/tests/scxml/w3c/mandatory/test322.scxml @@ -1,4 +1,5 @@ - diff --git a/tests/scxml/w3c/mandatory/test350-fail.scxml b/tests/scxml/w3c/mandatory/test350-fail.scxml index d456cd09..8d3e07de 100644 --- a/tests/scxml/w3c/mandatory/test350-fail.scxml +++ b/tests/scxml/w3c/mandatory/test350-fail.scxml @@ -10,7 +10,7 @@ able to send an event to itself using its own session ID as the target --> - + diff --git a/tests/scxml/w3c/mandatory/test351-fail.scxml b/tests/scxml/w3c/mandatory/test351-fail.scxml index 6d9c4969..0a40f0f3 100644 --- a/tests/scxml/w3c/mandatory/test351-fail.scxml +++ b/tests/scxml/w3c/mandatory/test351-fail.scxml @@ -11,7 +11,7 @@ otherwise --> - + @@ -30,7 +30,7 @@ otherwise --> - + diff --git a/tests/scxml/w3c/mandatory/test352-fail.scxml b/tests/scxml/w3c/mandatory/test352-fail.scxml index 32ef2768..b45006a4 100644 --- a/tests/scxml/w3c/mandatory/test352-fail.scxml +++ b/tests/scxml/w3c/mandatory/test352-fail.scxml @@ -8,7 +8,7 @@ - + diff --git a/tests/scxml/w3c/mandatory/test354.scxml b/tests/scxml/w3c/mandatory/test354.scxml index ac4e8842..f7d19c8c 100644 --- a/tests/scxml/w3c/mandatory/test354.scxml +++ b/tests/scxml/w3c/mandatory/test354.scxml @@ -11,7 +11,7 @@ and that correct values are used --> - + @@ -38,7 +38,7 @@ and that correct values are used --> - + foo diff --git a/tests/scxml/w3c/mandatory/test403a-fail.scxml b/tests/scxml/w3c/mandatory/test403a.scxml similarity index 100% rename from tests/scxml/w3c/mandatory/test403a-fail.scxml rename to tests/scxml/w3c/mandatory/test403a.scxml diff --git a/tests/scxml/w3c/mandatory/test409.scxml b/tests/scxml/w3c/mandatory/test409-fail.scxml similarity index 100% rename from tests/scxml/w3c/mandatory/test409.scxml rename to tests/scxml/w3c/mandatory/test409-fail.scxml diff --git a/tests/scxml/w3c/mandatory/test419-fail.scxml b/tests/scxml/w3c/mandatory/test419.scxml similarity index 100% rename from tests/scxml/w3c/mandatory/test419-fail.scxml rename to tests/scxml/w3c/mandatory/test419.scxml diff --git a/tests/scxml/w3c/mandatory/test421-fail.scxml b/tests/scxml/w3c/mandatory/test421.scxml similarity index 100% rename from tests/scxml/w3c/mandatory/test421-fail.scxml rename to tests/scxml/w3c/mandatory/test421.scxml diff --git a/tests/scxml/w3c/mandatory/test423-fail.scxml b/tests/scxml/w3c/mandatory/test423.scxml similarity index 100% rename from tests/scxml/w3c/mandatory/test423-fail.scxml rename to tests/scxml/w3c/mandatory/test423.scxml diff --git a/tests/scxml/w3c/mandatory/test487-fail.scxml b/tests/scxml/w3c/mandatory/test487.scxml similarity index 100% rename from tests/scxml/w3c/mandatory/test487-fail.scxml rename to tests/scxml/w3c/mandatory/test487.scxml diff --git a/tests/scxml/w3c/mandatory/test495-fail.scxml b/tests/scxml/w3c/mandatory/test495.scxml similarity index 100% rename from tests/scxml/w3c/mandatory/test495-fail.scxml rename to tests/scxml/w3c/mandatory/test495.scxml diff --git a/tests/scxml/w3c/optional/test509-fail.scxml b/tests/scxml/w3c/optional/test509-fail.scxml index f8e13dc5..e898b41c 100644 --- a/tests/scxml/w3c/optional/test509-fail.scxml +++ b/tests/scxml/w3c/optional/test509-fail.scxml @@ -5,7 +5,7 @@ at the accessURI--> - + diff --git a/tests/scxml/w3c/optional/test518-fail.scxml b/tests/scxml/w3c/optional/test518-fail.scxml index 3291f888..c09c975b 100644 --- a/tests/scxml/w3c/optional/test518-fail.scxml +++ b/tests/scxml/w3c/optional/test518-fail.scxml @@ -7,7 +7,7 @@ - + diff --git a/tests/scxml/w3c/optional/test519-fail.scxml b/tests/scxml/w3c/optional/test519-fail.scxml index e13def86..f6d8e819 100644 --- a/tests/scxml/w3c/optional/test519-fail.scxml +++ b/tests/scxml/w3c/optional/test519-fail.scxml @@ -4,7 +4,7 @@ initial="s0" datamodel="ecmascript" version="1.0"> - + diff --git a/tests/scxml/w3c/optional/test520-fail.scxml b/tests/scxml/w3c/optional/test520-fail.scxml index 04af8bce..0f23a7b2 100644 --- a/tests/scxml/w3c/optional/test520-fail.scxml +++ b/tests/scxml/w3c/optional/test520-fail.scxml @@ -4,7 +4,7 @@ initial="s0" datamodel="ecmascript" version="1.0"> - + this is some content diff --git a/tests/scxml/w3c/optional/test522.scxml b/tests/scxml/w3c/optional/test522.scxml index d1d22262..74aa3941 100644 --- a/tests/scxml/w3c/optional/test522.scxml +++ b/tests/scxml/w3c/optional/test522.scxml @@ -5,7 +5,7 @@ to send a message to the processor --> initial="s0" datamodel="ecmascript" version="1.0"> - + diff --git a/tests/scxml/w3c/optional/test534-fail.scxml b/tests/scxml/w3c/optional/test534-fail.scxml index f3079185..4adc62e2 100644 --- a/tests/scxml/w3c/optional/test534-fail.scxml +++ b/tests/scxml/w3c/optional/test534-fail.scxml @@ -4,7 +4,7 @@ initial="s0" datamodel="python" version="1.0"> - + diff --git a/tests/test_copy.py b/tests/test_copy.py index 65579887..8e557a15 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -8,11 +8,9 @@ from statemachine import State from statemachine import StateMachine -from statemachine.exceptions import TransitionNotAllowed from statemachine.states import States logger = logging.getLogger(__name__) -DEBUG = logging.DEBUG def copy_pickle(obj): @@ -63,30 +61,20 @@ class MySM(StateMachine): publish = draft.to(published, cond="let_me_be_visible") - def on_transition(self, event: str): - logger.debug(f"{self.__class__.__name__} recorded {event} transition") - def let_me_be_visible(self): - logger.debug(f"{type(self).__name__} let_me_be_visible: True") return True class MyModel: def __init__(self, name: str) -> None: self.name = name - self.let_me_be_visible = False + self._let_me_be_visible = False def __repr__(self) -> str: return f"{type(self).__name__}@{id(self)}({self.name!r})" - def on_transition(self, event: str): - logger.debug(f"{type(self).__name__}({self.name!r}) recorded {event} transition") - @property def let_me_be_visible(self): - logger.debug( - f"{type(self).__name__}({self.name!r}) let_me_be_visible: {self._let_me_be_visible}" - ) return self._let_me_be_visible @let_me_be_visible.setter @@ -96,16 +84,19 @@ def let_me_be_visible(self, value): def test_copy(copy_method): sm = MySM(MyModel("main_model")) - sm2 = copy_method(sm) - with pytest.raises(TransitionNotAllowed): - sm2.send("publish") + assert sm.model is not sm2.model + assert sm.model.name == sm2.model.name + assert sm2.current_state == sm.current_state + sm2.model.let_me_be_visible = True + sm2.send("publish") + assert sm2.current_state == sm.published -def test_copy_with_listeners(caplog, copy_method): - model1 = MyModel("main_model") +def test_copy_with_listeners(copy_method): + model1 = MyModel("main_model") sm1 = MySM(model1) listener_1 = MyModel("observer_1") @@ -116,52 +107,20 @@ def test_copy_with_listeners(caplog, copy_method): sm2 = copy_method(sm1) assert sm1.model is not sm2.model + assert len(sm1._listeners) == len(sm2._listeners) + assert all( + listener.name == copied_listener.name + for listener, copied_listener in zip( + sm1._listeners.values(), sm2._listeners.values(), strict=False + ) + ) + + sm2.model.let_me_be_visible = True + for listener in sm2._listeners.values(): + listener.let_me_be_visible = True - caplog.set_level(logging.DEBUG, logger="tests") - - def assertions(sm, _reference): - caplog.clear() - if not sm._listeners: - pytest.fail("did not found any observer") - - for listener in sm._listeners.values(): - listener.let_me_be_visible = False - - with pytest.raises(TransitionNotAllowed): - sm.send("publish") - - sm.model.let_me_be_visible = True - - for listener in sm._listeners.values(): - with pytest.raises(TransitionNotAllowed): - sm.send("publish") - - listener.let_me_be_visible = True - - sm.send("publish") - - assert caplog.record_tuples == [ - ("tests.test_copy", DEBUG, "MySM let_me_be_visible: True"), - ("tests.test_copy", DEBUG, "MyModel('main_model') let_me_be_visible: False"), - ("tests.test_copy", DEBUG, "MySM let_me_be_visible: True"), - ("tests.test_copy", DEBUG, "MyModel('main_model') let_me_be_visible: True"), - ("tests.test_copy", DEBUG, "MyModel('observer_1') let_me_be_visible: False"), - ("tests.test_copy", DEBUG, "MySM let_me_be_visible: True"), - ("tests.test_copy", DEBUG, "MyModel('main_model') let_me_be_visible: True"), - ("tests.test_copy", DEBUG, "MyModel('observer_1') let_me_be_visible: True"), - ("tests.test_copy", DEBUG, "MyModel('observer_2') let_me_be_visible: False"), - ("tests.test_copy", DEBUG, "MySM let_me_be_visible: True"), - ("tests.test_copy", DEBUG, "MyModel('main_model') let_me_be_visible: True"), - ("tests.test_copy", DEBUG, "MyModel('observer_1') let_me_be_visible: True"), - ("tests.test_copy", DEBUG, "MyModel('observer_2') let_me_be_visible: True"), - ("tests.test_copy", DEBUG, "MySM recorded publish transition"), - ("tests.test_copy", DEBUG, "MyModel('main_model') recorded publish transition"), - ("tests.test_copy", DEBUG, "MyModel('observer_1') recorded publish transition"), - ("tests.test_copy", DEBUG, "MyModel('observer_2') recorded publish transition"), - ] - - assertions(sm1, "original") - assertions(sm2, "copy") + sm2.send("publish") + assert sm2.current_state == sm1.published def test_copy_with_enum(copy_method): diff --git a/tests/test_transitions.py b/tests/test_transitions.py index 6c0c2c41..a1406391 100644 --- a/tests/test_transitions.py +++ b/tests/test_transitions.py @@ -10,12 +10,7 @@ def test_transition_representation(campaign_machine): s = repr([t for t in campaign_machine.draft.transitions if t.event == "produce"][0]) - assert s == ( - "Transition(" - "State('Draft', id='draft', value='draft', initial=True, final=False), " - "State('Being produced', id='producing', value='producing', " - "initial=False, final=False), event='produce', internal=False)" - ) + assert s == ("Transition('Draft', 'Being produced', event='produce', internal=False)") def test_list_machine_events(classic_traffic_light_machine): diff --git a/tests/testcases/__init__.py b/tests/testcases/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/testcases/issue434.md b/tests/testcases/issue434.md deleted file mode 100644 index 3e029121..00000000 --- a/tests/testcases/issue434.md +++ /dev/null @@ -1,87 +0,0 @@ -### Issue 434 - -A StateMachine that exercises the example given on issue -#[434](https://github.com/fgmacedo/python-statemachine/issues/434). - - -```py ->>> from time import sleep ->>> from statemachine import StateMachine, State - ->>> class Model: -... def __init__(self, data: dict): -... self.data = data - ->>> class DataCheckerMachine(StateMachine): -... check_data = State(initial=True) -... data_good = State(final=True) -... data_bad = State(final=True) -... -... MAX_CYCLE_COUNT = 10 -... cycle_count = 0 -... -... cycle = ( -... check_data.to(data_good, cond="data_looks_good") -... | check_data.to(data_bad, cond="max_cycle_reached") -... | check_data.to.itself(internal=True) -... ) -... -... def data_looks_good(self): -... return self.model.data.get("value") > 10.0 -... -... def max_cycle_reached(self): -... return self.cycle_count > self.MAX_CYCLE_COUNT -... -... def after_cycle(self, event: str, source: State, target: State): -... print(f'Running {event} {self.cycle_count} from {source!s} to {target!s}.') -... self.cycle_count += 1 -... - -``` - -Run until we reach the max cycle without success: - -```py ->>> data = {"value": 1} ->>> sm1 = DataCheckerMachine(Model(data)) ->>> cycle_rate = 0.1 ->>> while not sm1.current_state.final: -... sm1.cycle() -... sleep(cycle_rate) -Running cycle 0 from Check data to Check data. -Running cycle 1 from Check data to Check data. -Running cycle 2 from Check data to Check data. -Running cycle 3 from Check data to Check data. -Running cycle 4 from Check data to Check data. -Running cycle 5 from Check data to Check data. -Running cycle 6 from Check data to Check data. -Running cycle 7 from Check data to Check data. -Running cycle 8 from Check data to Check data. -Running cycle 9 from Check data to Check data. -Running cycle 10 from Check data to Check data. -Running cycle 11 from Check data to Data bad. - -``` - - -Run simulating that the data turns good on the 5th iteration: - -```py ->>> data = {"value": 1} ->>> sm2 = DataCheckerMachine(Model(data)) ->>> cycle_rate = 0.1 ->>> while not sm2.current_state.final: -... sm2.cycle() -... if sm2.cycle_count == 5: -... print("Now data looks good!") -... data["value"] = 20 -... sleep(cycle_rate) -Running cycle 0 from Check data to Check data. -Running cycle 1 from Check data to Check data. -Running cycle 2 from Check data to Check data. -Running cycle 3 from Check data to Check data. -Running cycle 4 from Check data to Check data. -Now data looks good! -Running cycle 5 from Check data to Data good. - -``` diff --git a/tests/testcases/issue480.md b/tests/testcases/issue480.md deleted file mode 100644 index 71b78d37..00000000 --- a/tests/testcases/issue480.md +++ /dev/null @@ -1,43 +0,0 @@ - - -### Issue 480 - -A StateMachine that exercises the example given on issue -#[480](https://github.com/fgmacedo/python-statemachine/issues/480). - -Should be possible to trigger an event on the initial state activation handler. - -```py ->>> from statemachine import StateMachine, State ->>> ->>> class MyStateMachine(StateMachine): -... State_1 = State(initial=True) -... State_2 = State(final=True) -... Trans_1 = State_1.to(State_2) -... -... def __init__(self): -... super(MyStateMachine, self).__init__() -... -... def on_enter_State_1(self): -... print("Entering State_1 state") -... self.long_running_task() -... -... def on_exit_State_1(self): -... print("Exiting State_1 state") -... -... def on_enter_State_2(self): -... print("Entering State_2 state") -... -... def long_running_task(self): -... print("long running task process started") -... self.Trans_1() -... print("long running task process ended") -... ->>> sm = MyStateMachine() -Entering State_1 state -long running task process started -long running task process ended -Exiting State_1 state -Entering State_2 state - -``` diff --git a/tests/testcases/test_issue434.py b/tests/testcases/test_issue434.py new file mode 100644 index 00000000..59d682dc --- /dev/null +++ b/tests/testcases/test_issue434.py @@ -0,0 +1,73 @@ +from time import sleep + +import pytest + +from statemachine import State +from statemachine import StateMachine + + +class Model: + def __init__(self, data: dict): + self.data = data + + +class DataCheckerMachine(StateMachine): + check_data = State(initial=True) + data_good = State(final=True) + data_bad = State(final=True) + + MAX_CYCLE_COUNT = 10 + cycle_count = 0 + + cycle = ( + check_data.to(data_good, cond="data_looks_good") + | check_data.to(data_bad, cond="max_cycle_reached") + | check_data.to.itself(internal=True) + ) + + def data_looks_good(self): + return self.model.data.get("value") > 10.0 + + def max_cycle_reached(self): + return self.cycle_count > self.MAX_CYCLE_COUNT + + def after_cycle(self, event: str, source: State, target: State): + print(f"Running {event} {self.cycle_count} from {source!s} to {target!s}.") + self.cycle_count += 1 + + +@pytest.fixture() +def initial_data(): + return {"value": 1} + + +@pytest.fixture() +def data_checker_machine(initial_data): + return DataCheckerMachine(Model(initial_data)) + + +def test_max_cycle_without_success(data_checker_machine): + sm = data_checker_machine + cycle_rate = 0.1 + + while not sm.current_state.final: + sm.cycle() + sleep(cycle_rate) + + assert sm.current_state == sm.data_bad + assert sm.cycle_count == 12 + + +def test_data_turns_good_mid_cycle(initial_data): + sm = DataCheckerMachine(Model(initial_data)) + cycle_rate = 0.1 + + while not sm.current_state.final: + sm.cycle() + if sm.cycle_count == 5: + print("Now data looks good!") + sm.model.data["value"] = 20 + sleep(cycle_rate) + + assert sm.current_state == sm.data_good + assert sm.cycle_count == 6 # Transition occurs at the 6th cycle diff --git a/tests/testcases/test_issue480.py b/tests/testcases/test_issue480.py new file mode 100644 index 00000000..4bea763a --- /dev/null +++ b/tests/testcases/test_issue480.py @@ -0,0 +1,56 @@ +""" + +### Issue 480 + +A StateMachine that exercises the example given on issue +#[480](https://github.com/fgmacedo/python-statemachine/issues/480). + +Should be possible to trigger an event on the initial state activation handler. +""" + +from unittest.mock import MagicMock +from unittest.mock import call + +from statemachine import State +from statemachine import StateMachine + + +class MyStateMachine(StateMachine): + state_1 = State(initial=True) + state_2 = State(final=True) + + trans_1 = state_1.to(state_2) + + def __init__(self): + self.mock = MagicMock() + super().__init__() + + def on_enter_state_1(self): + self.mock("on_enter_state_1") + self.long_running_task() + + def on_exit_state_1(self): + self.mock("on_exit_state_1") + + def on_enter_state_2(self): + self.mock("on_enter_state_2") + + def long_running_task(self): + self.mock("long_running_task_started") + self.trans_1() + self.mock("long_running_task_ended") + + +def test_initial_state_activation_handler(): + sm = MyStateMachine() + + expected_calls = [ + call("on_enter_state_1"), + call("long_running_task_started"), + call("long_running_task_ended"), + call("on_exit_state_1"), + call("on_enter_state_2"), + ] + + assert sm.mock.mock_calls == expected_calls + assert sm.current_state == sm.state_2