From 7724121403e8a19ebd18cee26dbfb1b216cbde9f Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 11 Dec 2024 12:11:33 +0100 Subject: [PATCH 01/22] Defer import of aio_pika --- src/plumpy/exceptions.py | 12 +++++++++--- src/plumpy/processes.py | 5 +++-- src/plumpy/rmq/__init__.py | 4 ++++ src/plumpy/rmq/exceptions.py | 11 +++++++++++ 4 files changed, 27 insertions(+), 5 deletions(-) create mode 100644 src/plumpy/rmq/__init__.py create mode 100644 src/plumpy/rmq/exceptions.py diff --git a/src/plumpy/exceptions.py b/src/plumpy/exceptions.py index 70b5aa2d..2f290e6a 100644 --- a/src/plumpy/exceptions.py +++ b/src/plumpy/exceptions.py @@ -1,7 +1,13 @@ # -*- coding: utf-8 -*- from typing import Optional -__all__ = ['ClosedError', 'InvalidStateError', 'KilledError', 'PersistenceError', 'UnsuccessfulResult'] +__all__ = [ + 'ClosedError', + 'InvalidStateError', + 'KilledError', + 'PersistenceError', + 'UnsuccessfulResult', +] class KilledError(Exception): @@ -9,8 +15,7 @@ class KilledError(Exception): class InvalidStateError(Exception): - """ - Raised when an operation is attempted that requires the process to be in a state + """Raised when an operation is attempted that requires the process to be in a state that is different from the current state """ @@ -33,3 +38,4 @@ class PersistenceError(Exception): class ClosedError(Exception): """Raised when an mutable operation is attempted on a closed process""" + diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 409374d0..7c5e08fc 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -38,7 +38,6 @@ import kiwipy import yaml -from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed from . import ( events, @@ -735,6 +734,8 @@ def on_entering(self, state: process_states.State) -> None: call_with_super_check(self.on_except, state.get_exc_info()) # type: ignore def on_entered(self, from_state: Optional[process_states.State]) -> None: + from plumpy.rmq.exceptions import CommunicatorChannelInvalidStateError, CommunicatorConnectionClosed + # Map these onto direct functions that the subclass can implement state_label = self._state.LABEL if state_label == process_states.ProcessState.RUNNING: @@ -754,7 +755,7 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject) try: self._communicator.broadcast_send(body=None, sender=self.pid, subject=subject) - except (ConnectionClosed, ChannelInvalidStateError): + except (CommunicatorConnectionClosed, CommunicatorChannelInvalidStateError): message = 'Process<%s>: no connection available to broadcast state change from %s to %s' self.logger.warning(message, self.pid, from_label, self.state.value) except kiwipy.TimeoutError: diff --git a/src/plumpy/rmq/__init__.py b/src/plumpy/rmq/__init__.py new file mode 100644 index 00000000..31d97783 --- /dev/null +++ b/src/plumpy/rmq/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +from .exceptions import * + +__all__ = exceptions.__all__ diff --git a/src/plumpy/rmq/exceptions.py b/src/plumpy/rmq/exceptions.py new file mode 100644 index 00000000..b15d51c4 --- /dev/null +++ b/src/plumpy/rmq/exceptions.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- +from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed + +__all__ = [ + 'CommunicatorChannelInvalidStateError', + 'CommunicatorConnectionClosed', +] + +# Alias aio_pika +CommunicatorConnectionClosed = ConnectionClosed +CommunicatorChannelInvalidStateError = ChannelInvalidStateError From 97746b43702896ed4e5d9db1c1b6fa1532ea9032 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 11 Dec 2024 23:50:36 +0100 Subject: [PATCH 02/22] Explicit future implementation: distinguish concurrent.future.Future and asyncio.Future hand write wrap to kiwipy future (concurrent.futures.Future) kiwipy.Future -> concurrent.futures.Future --- .python-version | 1 + src/plumpy/__init__.py | 1 + src/plumpy/communications.py | 28 +-- src/plumpy/exceptions.py | 1 - src/plumpy/futures.py | 64 +++---- src/plumpy/message.py | 299 +++++++++++++++++++++++++++++++++ src/plumpy/process_comms.py | 3 +- src/plumpy/processes.py | 6 +- src/plumpy/rmq/exceptions.py | 3 + src/plumpy/rmq/futures.py | 111 ++++++++++++ src/plumpy/workchains.py | 4 +- tests/rmq/test_communicator.py | 8 +- tests/test_processes.py | 6 +- tests/utils.py | 4 +- 14 files changed, 454 insertions(+), 85 deletions(-) create mode 100644 .python-version create mode 100644 src/plumpy/message.py create mode 100644 src/plumpy/rmq/futures.py diff --git a/.python-version b/.python-version new file mode 100644 index 00000000..413c7e7e --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +aiida-core-dev-3.12 diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index 46cac83a..adc302ef 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -16,6 +16,7 @@ from .process_listener import * from .process_states import * from .processes import * +from .rmq import * from .utils import * from .workchains import * diff --git a/src/plumpy/communications.py b/src/plumpy/communications.py index 1d7e775b..04e39d58 100644 --- a/src/plumpy/communications.py +++ b/src/plumpy/communications.py @@ -15,7 +15,6 @@ 'DeliveryFailed', 'RemoteException', 'TaskRejected', - 'plum_to_kiwi_future', 'wrap_communicator', ] @@ -36,31 +35,6 @@ BroadcastSubscriber = Callable[[kiwipy.Communicator, Any, Any, Any, ID_TYPE], Any] -def plum_to_kiwi_future(plum_future: futures.Future) -> kiwipy.Future: - """ - Return a kiwi future that resolves to the outcome of the plum future - - :param plum_future: the plum future - :return: the kiwipy future - - """ - kiwi_future = kiwipy.Future() - - def on_done(_plum_future: futures.Future) -> None: - with kiwipy.capture_exceptions(kiwi_future): - if plum_future.cancelled(): - kiwi_future.cancel() - else: - result = plum_future.result() - # Did we get another future? In which case convert it too - if isinstance(result, futures.Future): - result = plum_to_kiwi_future(result) - kiwi_future.set_result(result) - - plum_future.add_done_callback(on_done) - return kiwi_future - - def convert_to_comm( callback: 'Subscriber', loop: Optional[asyncio.AbstractEventLoop] = None ) -> Callable[..., kiwipy.Future]: @@ -97,7 +71,7 @@ def converted(communicator: kiwipy.Communicator, *args: Any, **kwargs: Any) -> k msg_fn = functools.partial(coro, communicator, *args, **kwargs) task_future = futures.create_task(msg_fn, loop) - return plum_to_kiwi_future(task_future) + return wrap_to_concurrent_future(task_future) return converted diff --git a/src/plumpy/exceptions.py b/src/plumpy/exceptions.py index 2f290e6a..6f0c75a4 100644 --- a/src/plumpy/exceptions.py +++ b/src/plumpy/exceptions.py @@ -38,4 +38,3 @@ class PersistenceError(Exception): class ClosedError(Exception): """Raised when an mutable operation is attempted on a closed process""" - diff --git a/src/plumpy/futures.py b/src/plumpy/futures.py index f52a0d09..a467f5d8 100644 --- a/src/plumpy/futures.py +++ b/src/plumpy/futures.py @@ -4,24 +4,33 @@ """ import asyncio -from typing import Any, Awaitable, Callable, Optional +import contextlib +from typing import Any, Awaitable, Callable, Generator, Optional -import kiwipy +__all__ = ['CancellableAction', 'create_task', 'create_task'] -__all__ = ['CancelledError', 'Future', 'chain', 'copy_future', 'create_task', 'gather'] -CancelledError = kiwipy.CancelledError +class InvalidFutureError(Exception): + """Exception for when a future or action is in an invalid state""" -class InvalidStateError(Exception): - """Exception for when a future or action is in an invalid state""" +Future = asyncio.Future -copy_future = kiwipy.copy_future -chain = kiwipy.chain -gather = asyncio.gather +@contextlib.contextmanager +def capture_exceptions(future: Future[Any], ignore: tuple[type[BaseException], ...] = ()) -> Generator[None, Any, None]: + """ + Capture any exceptions in the context and set them as the result of the given future -Future = asyncio.Future + :param future: The future to the exception on + :param ignore: An optional list of exception types to ignore, these will be raised and not set on the future + """ + try: + yield + except ignore: + raise + except Exception as exception: + future.set_exception(exception) class CancellableAction(Future): @@ -46,10 +55,10 @@ def run(self, *args: Any, **kwargs: Any) -> None: :param kwargs: the keyword arguments to the action """ if self.done(): - raise InvalidStateError('Action has already been ran') + raise InvalidFutureError('Action has already been ran') try: - with kiwipy.capture_exceptions(self): + with capture_exceptions(self): self.set_result(self._action(*args, **kwargs)) finally: self._action = None # type: ignore @@ -70,38 +79,9 @@ def create_task(coro: Callable[[], Awaitable[Any]], loop: Optional[asyncio.Abstr future = loop.create_future() async def run_task() -> None: - with kiwipy.capture_exceptions(future): + with capture_exceptions(future): res = await coro() future.set_result(res) asyncio.run_coroutine_threadsafe(run_task(), loop) return future - - -def unwrap_kiwi_future(future: kiwipy.Future) -> kiwipy.Future: - """ - Create a kiwi future that represents the final results of a nested series of futures, - meaning that if the futures provided itself resolves to a future the returned - future will not resolve to a value until the final chain of futures is not a future - but a concrete value. If at any point in the chain a future resolves to an exception - then the returned future will also resolve to that exception. - - :param future: the future to unwrap - :return: the unwrapping future - - """ - unwrapping = kiwipy.Future() - - def unwrap(fut: kiwipy.Future) -> None: - if fut.cancelled(): - unwrapping.cancel() - else: - with kiwipy.capture_exceptions(unwrapping): - result = fut.result() - if isinstance(result, kiwipy.Future): - result.add_done_callback(unwrap) - else: - unwrapping.set_result(result) - - future.add_done_callback(unwrap) - return unwrapping diff --git a/src/plumpy/message.py b/src/plumpy/message.py new file mode 100644 index 00000000..47586d21 --- /dev/null +++ b/src/plumpy/message.py @@ -0,0 +1,299 @@ +# -*- coding: utf-8 -*- +"""Module for process level communication functions and classes""" + +import asyncio +import logging +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast + +from plumpy.coordinator import Communicator +from plumpy.exceptions import PersistenceError, TaskRejectedError + +from . import loaders, persistence +from .utils import PID_TYPE + +__all__ = [ + 'KILL_MSG', + 'PAUSE_MSG', + 'PLAY_MSG', + 'STATUS_MSG', + 'ProcessLauncher', + 'create_continue_body', + 'create_launch_body', +] + +if TYPE_CHECKING: + from .processes import Process + +INTENT_KEY = 'intent' +MESSAGE_KEY = 'message' + + +class Intent: + """Intent constants for a process message""" + + PLAY: str = 'play' + PAUSE: str = 'pause' + KILL: str = 'kill' + STATUS: str = 'status' + + +PAUSE_MSG = {INTENT_KEY: Intent.PAUSE} +PLAY_MSG = {INTENT_KEY: Intent.PLAY} +KILL_MSG = {INTENT_KEY: Intent.KILL} +STATUS_MSG = {INTENT_KEY: Intent.STATUS} + +TASK_KEY = 'task' +TASK_ARGS = 'args' +PERSIST_KEY = 'persist' +# Launch +PROCESS_CLASS_KEY = 'process_class' +ARGS_KEY = 'init_args' +KWARGS_KEY = 'init_kwargs' +NOWAIT_KEY = 'nowait' +# Continue +PID_KEY = 'pid' +TAG_KEY = 'tag' +# Task types +LAUNCH_TASK = 'launch' +CONTINUE_TASK = 'continue' +CREATE_TASK = 'create' + +LOGGER = logging.getLogger(__name__) + + +def create_launch_body( + process_class: str, + init_args: Optional[Sequence[Any]] = None, + init_kwargs: Optional[Dict[str, Any]] = None, + persist: bool = False, + loader: Optional[loaders.ObjectLoader] = None, + nowait: bool = True, +) -> Dict[str, Any]: + """ + Create a message body for the launch action + + :param process_class: the class of the process to launch + :param init_args: any initialisation positional arguments + :param init_kwargs: any initialisation keyword arguments + :param persist: persist this process if True, otherwise don't + :param loader: the loader to use to load the persisted process + :param nowait: wait for the process to finish before completing the task, otherwise just return the PID + :return: a dictionary with the body of the message to launch the process + :rtype: dict + """ + if loader is None: + loader = loaders.get_object_loader() + + msg_body = { + TASK_KEY: LAUNCH_TASK, + TASK_ARGS: { + PROCESS_CLASS_KEY: loader.identify_object(process_class), + PERSIST_KEY: persist, + NOWAIT_KEY: nowait, + ARGS_KEY: init_args, + KWARGS_KEY: init_kwargs, + }, + } + return msg_body + + +def create_continue_body(pid: 'PID_TYPE', tag: Optional[str] = None, nowait: bool = False) -> Dict[str, Any]: + """ + Create a message body to continue an existing process + :param pid: the pid of the existing process + :param tag: the optional persistence tag + :param nowait: wait for the process to finish before completing the task, otherwise just return the PID + :return: a dictionary with the body of the message to continue the process + + """ + msg_body = {TASK_KEY: CONTINUE_TASK, TASK_ARGS: {PID_KEY: pid, NOWAIT_KEY: nowait, TAG_KEY: tag}} + return msg_body + + +def create_create_body( + process_class: str, + init_args: Optional[Sequence[Any]] = None, + init_kwargs: Optional[Dict[str, Any]] = None, + persist: bool = False, + loader: Optional[loaders.ObjectLoader] = None, +) -> Dict[str, Any]: + """ + Create a message body to create a new process + :param process_class: the class of the process to launch + :param init_args: any initialisation positional arguments + :param init_kwargs: any initialisation keyword arguments + :param persist: persist this process if True, otherwise don't + :param loader: the loader to use to load the persisted process + :return: a dictionary with the body of the message to launch the process + + """ + if loader is None: + loader = loaders.get_object_loader() + + msg_body = { + TASK_KEY: CREATE_TASK, + TASK_ARGS: { + PROCESS_CLASS_KEY: loader.identify_object(process_class), + PERSIST_KEY: persist, + ARGS_KEY: init_args, + KWARGS_KEY: init_kwargs, + }, + } + return msg_body + + +class ProcessLauncher: + """ + Takes incoming task messages and uses them to launch processes. + + Expected format of task: + + For launch:: + + { + 'task': + 'process_class': + 'args': + 'kwargs': . + 'nowait': True or False + } + + For continue:: + + { + 'task': + 'pid': + 'nowait': True or False + } + """ + + def __init__( + self, + loop: Optional[asyncio.AbstractEventLoop] = None, + persister: Optional[persistence.Persister] = None, + load_context: Optional[persistence.LoadSaveContext] = None, + loader: Optional[loaders.ObjectLoader] = None, + ) -> None: + self._loop = loop + self._persister = persister + self._load_context = load_context if load_context is not None else persistence.LoadSaveContext() + + if loader is not None: + self._loader = loader + self._load_context = self._load_context.copyextend(loader=loader) + else: + self._loader = loaders.get_object_loader() + + async def __call__(self, communicator: Communicator, task: Dict[str, Any]) -> Union[PID_TYPE, Any]: + """ + Receive a task. + :param task: The task message + """ + task_type = task[TASK_KEY] + if task_type == LAUNCH_TASK: + return await self._launch(**task.get(TASK_ARGS, {})) + if task_type == CONTINUE_TASK: + return await self._continue(**task.get(TASK_ARGS, {})) + if task_type == CREATE_TASK: + return await self._create(**task.get(TASK_ARGS, {})) + + raise TaskRejectedError + + async def _launch( + self, + process_class: str, + persist: bool, + nowait: bool, + init_args: Optional[Sequence[Any]] = None, + init_kwargs: Optional[Dict[str, Any]] = None, + ) -> Union[PID_TYPE, Any]: + """ + Launch the process + + :param _communicator: the communicator + :param process_class: the process class to launch + :param persist: should the process be persisted + :param nowait: if True only return when the process finishes + :param init_args: positional arguments to the process constructor + :param init_kwargs: keyword arguments to the process constructor + :return: the pid of the created process or the outputs (if nowait=False) + """ + if persist and not self._persister: + raise PersistenceError('Cannot persist process, no persister') + + if init_args is None: + init_args = () + if init_kwargs is None: + init_kwargs = {} + + proc_class = self._loader.load_object(process_class) + proc = proc_class(*init_args, **init_kwargs) + if persist and self._persister is not None: + self._persister.save_checkpoint(proc) + + if nowait: + # XXX: can return a reference and gracefully use task to cancel itself when the upper call stack fails + asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006 + return proc.pid + + await proc.step_until_terminated() + + return proc.future().result() + + async def _continue(self, pid: 'PID_TYPE', nowait: bool, tag: Optional[str] = None) -> Union[PID_TYPE, Any]: + """ + Continue the process + + :param _communicator: the communicator + :param pid: the pid of the process to continue + :param nowait: if True don't wait for the process to complete + :param tag: the checkpoint tag to continue from + """ + if not self._persister: + LOGGER.warning('rejecting task: cannot continue process<%d> because no persister is available', pid) + raise PersistenceError('Cannot continue process, no persister') + + # Do not catch exceptions here, because if these operations fail, the continue task should except and bubble up + saved_state = self._persister.load_checkpoint(pid, tag) + proc = cast('Process', saved_state.unbundle(self._load_context)) + + if nowait: + # XXX: can return a reference and gracefully use task to cancel itself when the upper call stack fails + asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006 + return proc.pid + + await proc.step_until_terminated() + + return proc.future().result() + + async def _create( + self, + process_class: str, + persist: bool, + init_args: Optional[Sequence[Any]] = None, + init_kwargs: Optional[Dict[str, Any]] = None, + ) -> 'PID_TYPE': + """ + Create the process + + :param _communicator: the communicator + :param process_class: the process class to create + :param persist: should the process be persisted + :param init_args: positional arguments to the process constructor + :param init_kwargs: keyword arguments to the process constructor + :return: the pid of the created process + """ + if persist and not self._persister: + raise PersistenceError('Cannot persist process, no persister') + + if init_args is None: + init_args = () + if init_kwargs is None: + init_kwargs = {} + + proc_class = self._loader.load_object(process_class) + proc = proc_class(*init_args, **init_kwargs) + if persist and self._persister is not None: + self._persister.save_checkpoint(proc) + + return proc.pid diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 2d6b3bf4..98f7128e 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -475,11 +475,10 @@ def execute_process( :param no_reply: if True, this call will be fire-and-forget, i.e. no return value :return: the result of executing the process """ - message = create_create_body(process_class, init_args, init_kwargs, persist=True, loader=loader) execute_future = kiwipy.Future() - create_future = futures.unwrap_kiwi_future(self._communicator.task_send(message)) + create_future = self._communicator.task_send(message) def on_created(_: Any) -> None: with kiwipy.capture_exceptions(execute_future): diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 7c5e08fc..89b84ef4 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -977,7 +977,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: MessageType) -> Any: def broadcast_receive( self, _comm: kiwipy.Communicator, msg: MessageType, sender: Any, subject: Any, correlation_id: Any - ) -> Optional[kiwipy.Future]: + ) -> Optional[concurrent.futures.Future]: """ Coroutine called when the process receives a message from the communicator @@ -1002,7 +1002,7 @@ def broadcast_receive( return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None)) return None - def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> kiwipy.Future: + def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> concurrent.futures.Future: """ Schedule a call to a callback as a result of an RPC communication call, this will return a future that resolves to the final result (even after one or more layer of futures being @@ -1017,7 +1017,7 @@ def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) :return: a kiwi future that resolves to the outcome of the callback """ - kiwi_future = kiwipy.Future() + kiwi_future = concurrent.futures.Future() async def run_callback() -> None: with kiwipy.capture_exceptions(kiwi_future): diff --git a/src/plumpy/rmq/exceptions.py b/src/plumpy/rmq/exceptions.py index b15d51c4..02eb3c97 100644 --- a/src/plumpy/rmq/exceptions.py +++ b/src/plumpy/rmq/exceptions.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import kiwipy from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed __all__ = [ @@ -9,3 +10,5 @@ # Alias aio_pika CommunicatorConnectionClosed = ConnectionClosed CommunicatorChannelInvalidStateError = ChannelInvalidStateError + +CancelledError = kiwipy.CancelledError diff --git a/src/plumpy/rmq/futures.py b/src/plumpy/rmq/futures.py new file mode 100644 index 00000000..897c8147 --- /dev/null +++ b/src/plumpy/rmq/futures.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# mypy: disable-error-code="no-untyped-def, no-untyped-call" +"""Module containing future related methods and classes""" + +import asyncio +import concurrent.futures +from typing import Any + +import kiwipy + +__all__ = ['wrap_to_concurrent_future'] + + +def _convert_future_exc(exc): + exc_class = type(exc) + if exc_class is concurrent.futures.CancelledError: + return asyncio.exceptions.CancelledError(*exc.args) + elif exc_class is concurrent.futures.TimeoutError: + return asyncio.exceptions.TimeoutError(*exc.args) + elif exc_class is concurrent.futures.InvalidStateError: + return asyncio.exceptions.InvalidStateError(*exc.args) + else: + return exc + + +def _set_concurrent_future_state(concurrent, source): + """Copy state from a future to a concurrent.futures.Future.""" + assert source.done() + if source.cancelled(): + concurrent.cancel() + if not concurrent.set_running_or_notify_cancel(): + return + exception = source.exception() + if exception is not None: + concurrent.set_exception(_convert_future_exc(exception)) + else: + result = source.result() + concurrent.set_result(result) + + +def _copy_future_state(source, dest): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert source.done() + if dest.cancelled(): + return + assert not dest.done() + if source.cancelled(): + dest.cancel() + else: + exception = source.exception() + if exception is not None: + dest.set_exception(_convert_future_exc(exception)) + else: + result = source.result() + dest.set_result(result) + + +def _chain_future(source, destination): + """Chain two futures so that when one completes, so does the other. + + The result (or exception) of source will be copied to destination. + If destination is cancelled, source gets cancelled too. + Compatible with both asyncio.Future and concurrent.futures.Future. + """ + if not asyncio.isfuture(source) and not isinstance(source, concurrent.futures.Future): + raise TypeError('A future is required for source argument') + if not asyncio.isfuture(destination) and not isinstance(destination, concurrent.futures.Future): + raise TypeError('A future is required for destination argument') + source_loop = asyncio.Future.get_loop(source) if asyncio.isfuture(source) else None + dest_loop = asyncio.Future.get_loop(destination) if asyncio.isfuture(destination) else None + + def _set_state(future, other): + if asyncio.isfuture(future): + _copy_future_state(other, future) + else: + _set_concurrent_future_state(future, other) + + def _call_check_cancel(destination): + if destination.cancelled(): + if source_loop is None or source_loop is dest_loop: + source.cancel() + else: + source_loop.call_soon_threadsafe(source.cancel) + + def _call_set_state(source): + if destination.cancelled() and dest_loop is not None and dest_loop.is_closed(): + return + if dest_loop is None or dest_loop is source_loop: + _set_state(destination, source) + else: + if dest_loop.is_closed(): + return + dest_loop.call_soon_threadsafe(_set_state, destination, source) + + destination.add_done_callback(_call_check_cancel) + source.add_done_callback(_call_set_state) + + +def wrap_to_concurrent_future(future: asyncio.Future[Any]) -> kiwipy.Future: + """Wrap to concurrent.futures.Future object. (the function is adapted from asyncio.future.wrap_future). + The function `_chain_future`, `_copy_future_state` is from asyncio future module.""" + if isinstance(future, concurrent.futures.Future): + return future + assert asyncio.isfuture(future), f'concurrent.futures.Future is expected, got {future!r}' + + new_future = kiwipy.Future() + _chain_future(future, new_future) + return new_future diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index b48b1c6b..7e67253f 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -23,7 +23,7 @@ cast, ) -import kiwipy +from plumpy.coordinator import Communicator from . import lang, mixins, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE @@ -128,7 +128,7 @@ def __init__( pid: Optional[PID_TYPE] = None, logger: Optional[logging.Logger] = None, loop: Optional[asyncio.AbstractEventLoop] = None, - communicator: Optional[kiwipy.Communicator] = None, + communicator: Optional[Communicator] = None, ) -> None: super().__init__(inputs=inputs, pid=pid, logger=logger, loop=loop, communicator=communicator) self._stepper: Optional[Stepper] = None diff --git a/tests/rmq/test_communicator.py b/tests/rmq/test_communicator.py index 8d6759be..a2bdae04 100644 --- a/tests/rmq/test_communicator.py +++ b/tests/rmq/test_communicator.py @@ -66,7 +66,7 @@ class TestLoopCommunicator: @pytest.mark.asyncio async def test_broadcast(self, loop_communicator): BROADCAST = {'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420} # noqa: N806 - broadcast_future = plumpy.Future() + broadcast_future = asyncio.Future() loop = asyncio.get_event_loop() @@ -85,7 +85,7 @@ def get_broadcast(_comm, body, sender, subject, correlation_id): @pytest.mark.asyncio async def test_broadcast_filter(self, loop_communicator): - broadcast_future = plumpy.Future() + broadcast_future = asyncio.Future() def ignore_broadcast(_comm, body, sender, subject, correlation_id): broadcast_future.set_exception(AssertionError('broadcast received')) @@ -105,7 +105,7 @@ def get_broadcast(_comm, body, sender, subject, correlation_id): @pytest.mark.asyncio async def test_rpc(self, loop_communicator): MSG = 'rpc this' # noqa: N806 - rpc_future = plumpy.Future() + rpc_future = asyncio.Future() loop = asyncio.get_event_loop() @@ -122,7 +122,7 @@ def get_rpc(_comm, msg): @pytest.mark.asyncio async def test_task(self, loop_communicator): TASK = 'task this' # noqa: N806 - task_future = plumpy.Future() + task_future = asyncio.Future() loop = asyncio.get_event_loop() diff --git a/tests/test_processes.py b/tests/test_processes.py index 5d3184f2..3d6b4394 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -7,6 +7,8 @@ import kiwipy import pytest +from plumpy.futures import CancellableAction +from tests import utils import plumpy from plumpy import BundleKeys, Process, ProcessState @@ -537,7 +539,7 @@ def test_pause_in_process(self): class TestPausePlay(plumpy.Process): def run(self): fut = self.pause() - test_case.assertIsInstance(fut, plumpy.Future) + assert isinstance(fut, CancellableAction) loop = asyncio.get_event_loop() @@ -561,7 +563,7 @@ def test_pause_play_in_process(self): class TestPausePlay(plumpy.Process): def run(self): fut = self.pause() - test_case.assertIsInstance(fut, plumpy.Future) + test_case.assertIsInstance(fut, CancellableAction) result = self.play() test_case.assertTrue(result) diff --git a/tests/utils.py b/tests/utils.py index 13abc38c..05290990 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -466,7 +466,7 @@ def run_until_waiting(proc): from plumpy import ProcessState listener = plumpy.ProcessListener() - in_waiting = plumpy.Future() + in_waiting = asyncio.Future() if proc.state == ProcessState.WAITING: in_waiting.set_result(True) @@ -486,7 +486,7 @@ def run_until_paused(proc): """Set up a future that will be resolved when the process is paused""" listener = plumpy.ProcessListener() - paused = plumpy.Future() + paused = asyncio.Future() if proc.paused: paused.set_result(True) From f5e5ec4826e58b67b1767985881c3903b0b20cf7 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 11 Dec 2024 20:59:07 +0100 Subject: [PATCH 03/22] Move communication into rmq module Move communication to rmq --- docs/source/nitpick-exceptions | 2 +- src/plumpy/__init__.py | 8 +- src/plumpy/futures.py | 2 +- src/plumpy/processes.py | 52 +++-- src/plumpy/rmq/__init__.py | 4 +- src/plumpy/{ => rmq}/communications.py | 4 +- src/plumpy/{ => rmq}/process_comms.py | 182 ++---------------- tests/{ => rmq}/test_communications.py | 4 +- tests/rmq/test_communicator.py | 2 +- tests/rmq/test_process_comms.py | 4 +- ...{test_process_comms.py => test_message.py} | 6 +- 11 files changed, 54 insertions(+), 216 deletions(-) rename src/plumpy/{ => rmq}/communications.py (98%) rename src/plumpy/{ => rmq}/process_comms.py (74%) rename tests/{ => rmq}/test_communications.py (95%) rename tests/{test_process_comms.py => test_message.py} (90%) diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 2f354987..6aa8c345 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -23,7 +23,7 @@ py:class plumpy.base.state_machine.State py:class State py:class Process py:class plumpy.futures.CancellableAction -py:class plumpy.communications.LoopCommunicator +py:class plumpy.rmq.communications.LoopCommunicator py:class plumpy.persistence.PersistedPickle py:class plumpy.utils.AttributesFrozendict py:class plumpy.workchains._FunctionCall diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index adc302ef..8f62edb6 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -4,21 +4,20 @@ import logging -from .communications import * from .events import * from .exceptions import * from .futures import * from .loaders import * +from .message import * from .mixins import * from .persistence import * from .ports import * -from .process_comms import * from .process_listener import * from .process_states import * from .processes import * -from .rmq import * from .utils import * from .workchains import * +from .rmq import * __all__ = ( events.__all__ @@ -28,8 +27,7 @@ + futures.__all__ + mixins.__all__ + persistence.__all__ - + communications.__all__ - + process_comms.__all__ + + message.__all__ + process_listener.__all__ + workchains.__all__ + loaders.__all__ diff --git a/src/plumpy/futures.py b/src/plumpy/futures.py index a467f5d8..2f861d64 100644 --- a/src/plumpy/futures.py +++ b/src/plumpy/futures.py @@ -7,7 +7,7 @@ import contextlib from typing import Any, Awaitable, Callable, Generator, Optional -__all__ = ['CancellableAction', 'create_task', 'create_task'] +__all__ = ['CancellableAction', 'create_task', 'create_task', 'capture_exceptions'] class InvalidFutureError(Exception): diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 89b84ef4..5b7c951d 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -39,16 +39,8 @@ import kiwipy import yaml -from . import ( - events, - exceptions, - futures, - persistence, - ports, - process_comms, - process_states, - utils, -) +from . import events, exceptions, message, persistence, ports, process_states, utils +from .futures import capture_exceptions, CancellableAction from .base import state_machine from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event from .base.utils import call_with_super_check, super_check @@ -153,10 +145,10 @@ class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMe _spec_class = ProcessSpec # Default placeholders, will be populated in init() _stepping = False - _pausing: Optional[futures.CancellableAction] = None + _pausing: Optional[CancellableAction] = None _paused: Optional[persistence.SavableFuture] = None - _killing: Optional[futures.CancellableAction] = None - _interrupt_action: Optional[futures.CancellableAction] = None + _killing: Optional[CancellableAction] = None + _interrupt_action: Optional[CancellableAction] = None _closed = False _cleanups: Optional[List[Callable[[], None]]] = None @@ -341,7 +333,7 @@ def init(self) -> None: if not self._future.done(): - def try_killing(future: futures.Future) -> None: + def try_killing(future: asyncio.Future) -> None: if future.cancelled(): if not self.kill('Killed by future being cancelled'): self.logger.warning( @@ -959,15 +951,15 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: MessageType) -> Any: msg, ) - intent = msg[process_comms.INTENT_KEY] + intent = msg[message.INTENT_KEY] - if intent == process_comms.Intent.PLAY: + if intent == message.Intent.PLAY: return self._schedule_rpc(self.play) - if intent == process_comms.Intent.PAUSE: - return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None)) - if intent == process_comms.Intent.KILL: - return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None)) - if intent == process_comms.Intent.STATUS: + if intent == message.Intent.PAUSE: + return self._schedule_rpc(self.pause, msg_text=msg.get(message.MESSAGE_KEY, None)) + if intent == message.Intent.KILL: + return self._schedule_rpc(self.kill, msg_text=msg.get(message.MESSAGE_KEY, None)) + if intent == message.Intent.STATUS: status_info: Dict[str, Any] = {} self.get_status_info(status_info) return status_info @@ -994,7 +986,7 @@ def broadcast_receive( ) # If we get a message we recognise then action it, otherwise ignore - if subject == process_comms.Intent.PLAY: + if subject == message.Intent.PLAY: return self._schedule_rpc(self.play) if subject == process_comms.Intent.PAUSE: return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None)) @@ -1020,7 +1012,7 @@ def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) kiwi_future = concurrent.futures.Future() async def run_callback() -> None: - with kiwipy.capture_exceptions(kiwi_future): + with capture_exceptions(kiwi_future): try: result = callback(*args, **kwargs) except Exception as exc: @@ -1097,7 +1089,7 @@ def transition_failed( ) self.transition_to(new_state) - def pause(self, msg_text: Optional[str] = None) -> Union[bool, futures.CancellableAction]: + def pause(self, msg_text: Optional[str] = None) -> Union[bool, CancellableAction]: """Pause the process. :param msg: an optional message to set as the status. The current status will be saved in the private @@ -1126,7 +1118,7 @@ def pause(self, msg_text: Optional[str] = None) -> Union[bool, futures.Cancellab self._pausing = self._interrupt_action # Try to interrupt the state self._state.interrupt(interrupt_exception) - return cast(futures.CancellableAction, self._interrupt_action) + return cast(CancellableAction, self._interrupt_action) msg = MessageBuilder.pause(msg_text) return self._do_pause(state_msg=msg) @@ -1149,7 +1141,7 @@ def _do_pause(self, state_msg: Optional[MessageType], next_state: Optional[proce return True - def _create_interrupt_action(self, exception: process_states.Interruption) -> futures.CancellableAction: + def _create_interrupt_action(self, exception: process_states.Interruption) -> CancellableAction: """ Create an interrupt action from the corresponding interrupt exception @@ -1159,7 +1151,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu """ if isinstance(exception, process_states.PauseInterruption): do_pause = functools.partial(self._do_pause, exception.msg) - return futures.CancellableAction(do_pause, cookie=exception) + return CancellableAction(do_pause, cookie=exception) if isinstance(exception, process_states.KillInterruption): @@ -1171,11 +1163,11 @@ def do_kill(_next_state: process_states.State) -> Any: finally: self._killing = None - return futures.CancellableAction(do_kill, cookie=exception) + return CancellableAction(do_kill, cookie=exception) raise ValueError(f"Got unknown interruption type '{type(exception)}'") - def _set_interrupt_action(self, new_action: Optional[futures.CancellableAction]) -> None: + def _set_interrupt_action(self, new_action: Optional[CancellableAction]) -> None: """ Set the interrupt action cancelling the current one if it exists :param new_action: The new interrupt action to set @@ -1247,7 +1239,7 @@ def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]: self._set_interrupt_action_from_exception(interrupt_exception) self._killing = self._interrupt_action self._state.interrupt(interrupt_exception) - return cast(futures.CancellableAction, self._interrupt_action) + return cast(CancellableAction, self._interrupt_action) msg = MessageBuilder.kill(msg_text) new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg) diff --git a/src/plumpy/rmq/__init__.py b/src/plumpy/rmq/__init__.py index 31d97783..ad0642ca 100644 --- a/src/plumpy/rmq/__init__.py +++ b/src/plumpy/rmq/__init__.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- from .exceptions import * +from .futures import * +from .process_comms import * -__all__ = exceptions.__all__ +__all__ = exceptions.__all__ + communications.__all__ + futures.__all__ + process_comms.__all__ diff --git a/src/plumpy/communications.py b/src/plumpy/rmq/communications.py similarity index 98% rename from src/plumpy/communications.py rename to src/plumpy/rmq/communications.py index 04e39d58..b66e9694 100644 --- a/src/plumpy/communications.py +++ b/src/plumpy/rmq/communications.py @@ -7,8 +7,8 @@ import kiwipy -from . import futures -from .utils import ensure_coroutine +from plumpy import futures +from plumpy.utils import ensure_coroutine __all__ = [ 'Communicator', diff --git a/src/plumpy/process_comms.py b/src/plumpy/rmq/process_comms.py similarity index 74% rename from src/plumpy/process_comms.py rename to src/plumpy/rmq/process_comms.py index 98f7128e..010fd67d 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/rmq/process_comms.py @@ -9,21 +9,28 @@ import kiwipy -from . import communications, futures, loaders, persistence -from .utils import PID_TYPE +from plumpy.message import ( + MESSAGE_KEY, + PAUSE_MSG, + PLAY_MSG, + STATUS_MSG, + KILL_MSG, + Intent, + create_continue_body, + create_create_body, + create_launch_body, +) + +from plumpy import loaders +from plumpy.utils import PID_TYPE __all__ = [ 'MessageBuilder', 'ProcessLauncher', 'RemoteProcessController', 'RemoteProcessThreadController', - 'create_continue_body', - 'create_launch_body', ] -if TYPE_CHECKING: - from .processes import Process - ProcessResult = Any ProcessStatus = Any @@ -498,164 +505,3 @@ def task_send(self, message: Any, no_reply: bool = False) -> Optional[Any]: :return: the response from the remote side (if no_reply=False) """ return self._communicator.task_send(message, no_reply=no_reply) - - -class ProcessLauncher: - """ - Takes incoming task messages and uses them to launch processes. - - Expected format of task: - - For launch:: - - { - 'task': - 'process_class': - 'args': - 'kwargs': . - 'nowait': True or False - } - - For continue:: - - { - 'task': - 'pid': - 'nowait': True or False - } - """ - - def __init__( - self, - loop: Optional[asyncio.AbstractEventLoop] = None, - persister: Optional[persistence.Persister] = None, - load_context: Optional[persistence.LoadSaveContext] = None, - loader: Optional[loaders.ObjectLoader] = None, - ) -> None: - self._loop = loop - self._persister = persister - self._load_context = load_context if load_context is not None else persistence.LoadSaveContext() - - if loader is not None: - self._loader = loader - self._load_context = self._load_context.copyextend(loader=loader) - else: - self._loader = loaders.get_object_loader() - - async def __call__(self, communicator: kiwipy.Communicator, task: Dict[str, Any]) -> Union[PID_TYPE, ProcessResult]: - """ - Receive a task. - :param task: The task message - """ - task_type = task[TASK_KEY] - if task_type == LAUNCH_TASK: - return await self._launch(communicator, **task.get(TASK_ARGS, {})) - if task_type == CONTINUE_TASK: - return await self._continue(communicator, **task.get(TASK_ARGS, {})) - if task_type == CREATE_TASK: - return await self._create(communicator, **task.get(TASK_ARGS, {})) - - raise communications.TaskRejected - - async def _launch( - self, - _communicator: kiwipy.Communicator, - process_class: str, - persist: bool, - nowait: bool, - init_args: Optional[Sequence[Any]] = None, - init_kwargs: Optional[Dict[str, Any]] = None, - ) -> Union[PID_TYPE, ProcessResult]: - """ - Launch the process - - :param _communicator: the communicator - :param process_class: the process class to launch - :param persist: should the process be persisted - :param nowait: if True only return when the process finishes - :param init_args: positional arguments to the process constructor - :param init_kwargs: keyword arguments to the process constructor - :return: the pid of the created process or the outputs (if nowait=False) - """ - if persist and not self._persister: - raise communications.TaskRejected('Cannot persist process, no persister') - - if init_args is None: - init_args = () - if init_kwargs is None: - init_kwargs = {} - - proc_class = self._loader.load_object(process_class) - proc = proc_class(*init_args, **init_kwargs) - if persist and self._persister is not None: - self._persister.save_checkpoint(proc) - - if nowait: - # XXX: can return a reference and gracefully use task to cancel itself when the upper call stack fails - asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006 - return proc.pid - - await proc.step_until_terminated() - - return proc.future().result() - - async def _continue( - self, _communicator: kiwipy.Communicator, pid: 'PID_TYPE', nowait: bool, tag: Optional[str] = None - ) -> Union[PID_TYPE, ProcessResult]: - """ - Continue the process - - :param _communicator: the communicator - :param pid: the pid of the process to continue - :param nowait: if True don't wait for the process to complete - :param tag: the checkpoint tag to continue from - """ - if not self._persister: - LOGGER.warning('rejecting task: cannot continue process<%d> because no persister is available', pid) - raise communications.TaskRejected('Cannot continue process, no persister') - - # Do not catch exceptions here, because if these operations fail, the continue task should except and bubble up - saved_state = self._persister.load_checkpoint(pid, tag) - proc = cast('Process', saved_state.unbundle(self._load_context)) - - if nowait: - # XXX: can return a reference and gracefully use task to cancel itself when the upper call stack fails - asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006 - return proc.pid - - await proc.step_until_terminated() - - return proc.future().result() - - async def _create( - self, - _communicator: kiwipy.Communicator, - process_class: str, - persist: bool, - init_args: Optional[Sequence[Any]] = None, - init_kwargs: Optional[Dict[str, Any]] = None, - ) -> 'PID_TYPE': - """ - Create the process - - :param _communicator: the communicator - :param process_class: the process class to create - :param persist: should the process be persisted - :param init_args: positional arguments to the process constructor - :param init_kwargs: keyword arguments to the process constructor - :return: the pid of the created process - """ - if persist and not self._persister: - raise communications.TaskRejected('Cannot persist process, no persister') - - if init_args is None: - init_args = () - if init_kwargs is None: - init_kwargs = {} - - proc_class = self._loader.load_object(process_class) - proc = proc_class(*init_args, **init_kwargs) - if persist and self._persister is not None: - self._persister.save_checkpoint(proc) - - return proc.pid diff --git a/tests/test_communications.py b/tests/rmq/test_communications.py similarity index 95% rename from tests/test_communications.py rename to tests/rmq/test_communications.py index f7e04255..63813bdc 100644 --- a/tests/test_communications.py +++ b/tests/rmq/test_communications.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- -"""Tests for the :mod:`plumpy.communications` module.""" +"""Tests for the :mod:`plumpy.rmq.communications` module.""" import pytest from kiwipy import CommunicatorHelper -from plumpy.communications import LoopCommunicator +from plumpy.rmq.communications import LoopCommunicator class Subscriber: diff --git a/tests/rmq/test_communicator.py b/tests/rmq/test_communicator.py index a2bdae04..ad067c25 100644 --- a/tests/rmq/test_communicator.py +++ b/tests/rmq/test_communicator.py @@ -13,7 +13,7 @@ from kiwipy import BroadcastFilter, rmq import plumpy -from plumpy import communications, process_comms +from plumpy.rmq import communications, process_comms from .. import utils diff --git a/tests/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py index 7a03fac4..454c0787 100644 --- a/tests/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -7,8 +7,8 @@ from kiwipy import rmq import plumpy -import plumpy.communications -from plumpy import process_comms +from plumpy.message import KILL_MSG, MESSAGE_KEY +from plumpy.rmq import process_comms from .. import utils diff --git a/tests/test_process_comms.py b/tests/test_message.py similarity index 90% rename from tests/test_process_comms.py rename to tests/test_message.py index 44947230..82951afd 100644 --- a/tests/test_process_comms.py +++ b/tests/test_message.py @@ -2,7 +2,7 @@ import pytest import plumpy -from plumpy import process_comms +from plumpy import message from tests import utils @@ -37,7 +37,7 @@ async def test_continue(): del process process = None - result = await launcher._continue(None, **plumpy.create_continue_body(pid)[process_comms.TASK_ARGS]) + result = await launcher._continue(None, **plumpy.create_continue_body(pid)[message.TASK_ARGS]) assert result == utils.DummyProcess.EXPECTED_OUTPUTS @@ -51,5 +51,5 @@ async def test_loader_is_used(): launcher = plumpy.ProcessLauncher(persister=persister, loader=loader) continue_task = plumpy.create_continue_body(proc.pid) - result = await launcher._continue(None, **continue_task[process_comms.TASK_ARGS]) + result = await launcher._continue(None, **continue_task[message.TASK_ARGS]) assert result == utils.DummyProcess.EXPECTED_OUTPUTS From c2c9a65a5a53b36067aa354d6949834ef92342bd Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Sat, 14 Dec 2024 14:27:28 +0100 Subject: [PATCH 04/22] Move TaskRejectError as the common exception for task launch --- src/plumpy/exceptions.py | 3 +++ src/plumpy/message.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/src/plumpy/exceptions.py b/src/plumpy/exceptions.py index 6f0c75a4..1e6f3b26 100644 --- a/src/plumpy/exceptions.py +++ b/src/plumpy/exceptions.py @@ -38,3 +38,6 @@ class PersistenceError(Exception): class ClosedError(Exception): """Raised when an mutable operation is attempted on a closed process""" + +class TaskRejectedError(Exception): + """ A task was rejected by the coordinacor""" diff --git a/src/plumpy/message.py b/src/plumpy/message.py index 47586d21..b5e1348a 100644 --- a/src/plumpy/message.py +++ b/src/plumpy/message.py @@ -8,6 +8,8 @@ from plumpy.coordinator import Communicator from plumpy.exceptions import PersistenceError, TaskRejectedError +from plumpy.exceptions import PersistenceError, TaskRejectedError + from . import loaders, persistence from .utils import PID_TYPE From 4dae773dd91b9081df6b7a32dbf255f408156981 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Sat, 14 Dec 2024 14:39:54 +0100 Subject: [PATCH 05/22] Remove useless communicator param passed to ProcessLaunch __call__ --- tests/test_message.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_message.py b/tests/test_message.py index 82951afd..0a6ee96c 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -37,7 +37,7 @@ async def test_continue(): del process process = None - result = await launcher._continue(None, **plumpy.create_continue_body(pid)[message.TASK_ARGS]) + result = await launcher._continue(**plumpy.create_continue_body(pid)[message.TASK_ARGS]) assert result == utils.DummyProcess.EXPECTED_OUTPUTS @@ -51,5 +51,5 @@ async def test_loader_is_used(): launcher = plumpy.ProcessLauncher(persister=persister, loader=loader) continue_task = plumpy.create_continue_body(proc.pid) - result = await launcher._continue(None, **continue_task[message.TASK_ARGS]) + result = await launcher._continue(**continue_task[message.TASK_ARGS]) assert result == utils.DummyProcess.EXPECTED_OUTPUTS From 56c18d4576e4b452650e05034e2536d6a30a4e14 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Sat, 14 Dec 2024 21:06:06 +0100 Subject: [PATCH 06/22] Forming Communicator protocol --- src/plumpy/coordinator.py | 21 ++++++++++++++++++++ src/plumpy/processes.py | 42 ++++++++++++++++++++++++--------------- 2 files changed, 47 insertions(+), 16 deletions(-) create mode 100644 src/plumpy/coordinator.py diff --git a/src/plumpy/coordinator.py b/src/plumpy/coordinator.py new file mode 100644 index 00000000..214fc18f --- /dev/null +++ b/src/plumpy/coordinator.py @@ -0,0 +1,21 @@ +from typing import Any, Callable, Protocol + +RpcSubscriber = Callable[['Communicator', Any], Any] +BroadcastSubscriber = Callable[['Communicator', Any, Any, Any, Any], Any] + +class Communicator(Protocol): + + def add_rpc_subscriber(self, subscriber: RpcSubscriber, identifier=None) -> Any: + ... + + def add_broadcast_subscriber(self, subscriber: BroadcastSubscriber, identifier=None) -> Any: + ... + + def remove_rpc_subscriber(self, identifier): + ... + + def remove_broadcast_subscriber(self, identifier): + ... + + def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: + ... diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 5b7c951d..f12874ed 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -3,6 +3,7 @@ import abc import asyncio +import concurrent.futures import contextlib import copy import enum @@ -31,6 +32,8 @@ cast, ) +from plumpy.coordinator import Communicator + try: from aiocontextvars import ContextVar except ModuleNotFoundError: @@ -264,7 +267,7 @@ def __init__( pid: Optional[PID_TYPE] = None, logger: Optional[logging.Logger] = None, loop: Optional[asyncio.AbstractEventLoop] = None, - communicator: Optional[kiwipy.Communicator] = None, + communicator: Optional[Communicator] = None, ) -> None: """ The signature of the constructor should not be changed by subclassing processes. @@ -317,19 +320,17 @@ def init(self) -> None: try: identifier = self._communicator.add_rpc_subscriber(self.message_receive, identifier=str(self.pid)) self.add_cleanup(functools.partial(self._communicator.remove_rpc_subscriber, identifier)) - except kiwipy.TimeoutError: + except concurrent.futures.TimeoutError: self.logger.exception('Process<%s>: failed to register as an RPC subscriber', self.pid) try: # filter out state change broadcasts + # TODO: pattern filter should be moved to add_broadcast_subscriber. subscriber = kiwipy.BroadcastFilter(self.broadcast_receive, subject=re.compile(r'^(?!state_changed).*')) identifier = self._communicator.add_broadcast_subscriber(subscriber, identifier=str(self.pid)) self.add_cleanup(functools.partial(self._communicator.remove_broadcast_subscriber, identifier)) - except kiwipy.TimeoutError: - self.logger.exception( - 'Process<%s>: failed to register as a broadcast subscriber', - self.pid, - ) + except concurrent.futures.TimeoutError: + self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid) if not self._future.done(): @@ -726,8 +727,6 @@ def on_entering(self, state: process_states.State) -> None: call_with_super_check(self.on_except, state.get_exc_info()) # type: ignore def on_entered(self, from_state: Optional[process_states.State]) -> None: - from plumpy.rmq.exceptions import CommunicatorChannelInvalidStateError, CommunicatorConnectionClosed - # Map these onto direct functions that the subclass can implement state_label = self._state.LABEL if state_label == process_states.ProcessState.RUNNING: @@ -742,6 +741,8 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: call_with_super_check(self.on_killed) if self._communicator and isinstance(self.state, enum.Enum): + from plumpy.rmq.exceptions import CommunicatorChannelInvalidStateError, CommunicatorConnectionClosed + from_label = cast(enum.Enum, from_state.LABEL).value if from_state is not None else None subject = f'state_changed.{from_label}.{self.state.value}' self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject) @@ -750,7 +751,7 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: except (CommunicatorConnectionClosed, CommunicatorChannelInvalidStateError): message = 'Process<%s>: no connection available to broadcast state change from %s to %s' self.logger.warning(message, self.pid, from_label, self.state.value) - except kiwipy.TimeoutError: + except concurrent.futures.TimeoutError: message = 'Process<%s>: sending broadcast of state change from %s to %s timed out' self.logger.warning(message, self.pid, from_label, self.state.value) @@ -936,7 +937,7 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non # region Communication - def message_receive(self, _comm: kiwipy.Communicator, msg: MessageType) -> Any: + def message_receive(self, _comm: Communicator, msg: MessageType) -> Any: """ Coroutine called when the process receives a message from the communicator @@ -984,15 +985,24 @@ def broadcast_receive( _comm, msg, ) - # If we get a message we recognise then action it, otherwise ignore + fn = None if subject == message.Intent.PLAY: - return self._schedule_rpc(self.play) - if subject == process_comms.Intent.PAUSE: + fn = self._schedule_rpc(self.play) + elif subject == message.Intent.PAUSE: return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None)) - if subject == process_comms.Intent.KILL: + elif subject == message.Intent.KILL: return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None)) - return None + + if fn is None: + self.logger.warning( + "Process<%s>: received unsupported broadcast message '%s'.", + self.pid, + subject, + ) + return None + + return fn def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> concurrent.futures.Future: """ From 2d9fdb95b1e2efd07a4f0b93ec8e4f7bb0257494 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Sun, 15 Dec 2024 00:07:55 +0100 Subject: [PATCH 07/22] Remove kiwipy/rmq dependencies of process module --- src/plumpy/coordinator.py | 22 ++-- src/plumpy/exceptions.py | 3 +- src/plumpy/futures.py | 2 +- src/plumpy/message.py | 40 ++++++++ src/plumpy/process_states.py | 2 +- src/plumpy/processes.py | 12 +-- src/plumpy/rmq/communications.py | 4 +- src/plumpy/rmq/process_comms.py | 168 +------------------------------ tests/rmq/test_communications.py | 2 +- tests/rmq/test_communicator.py | 5 +- tests/test_processes.py | 3 +- tests/utils.py | 2 +- 12 files changed, 73 insertions(+), 192 deletions(-) diff --git a/src/plumpy/coordinator.py b/src/plumpy/coordinator.py index 214fc18f..1daaf1f8 100644 --- a/src/plumpy/coordinator.py +++ b/src/plumpy/coordinator.py @@ -1,21 +1,19 @@ -from typing import Any, Callable, Protocol +# -*- coding: utf-8 -*- +from typing import Any, Callable, Pattern, Protocol RpcSubscriber = Callable[['Communicator', Any], Any] BroadcastSubscriber = Callable[['Communicator', Any, Any, Any, Any], Any] -class Communicator(Protocol): - def add_rpc_subscriber(self, subscriber: RpcSubscriber, identifier=None) -> Any: - ... +class Communicator(Protocol): + def add_rpc_subscriber(self, subscriber: RpcSubscriber, identifier=None) -> Any: ... - def add_broadcast_subscriber(self, subscriber: BroadcastSubscriber, identifier=None) -> Any: - ... + def add_broadcast_subscriber( + self, subscriber: BroadcastSubscriber, subject_filter: str | Pattern[str] | None = None, identifier=None + ) -> Any: ... - def remove_rpc_subscriber(self, identifier): - ... + def remove_rpc_subscriber(self, identifier): ... - def remove_broadcast_subscriber(self, identifier): - ... + def remove_broadcast_subscriber(self, identifier): ... - def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: - ... + def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: ... diff --git a/src/plumpy/exceptions.py b/src/plumpy/exceptions.py index 1e6f3b26..9dca8fdb 100644 --- a/src/plumpy/exceptions.py +++ b/src/plumpy/exceptions.py @@ -39,5 +39,6 @@ class PersistenceError(Exception): class ClosedError(Exception): """Raised when an mutable operation is attempted on a closed process""" + class TaskRejectedError(Exception): - """ A task was rejected by the coordinacor""" + """A task was rejected by the coordinacor""" diff --git a/src/plumpy/futures.py b/src/plumpy/futures.py index 2f861d64..01be3951 100644 --- a/src/plumpy/futures.py +++ b/src/plumpy/futures.py @@ -7,7 +7,7 @@ import contextlib from typing import Any, Awaitable, Callable, Generator, Optional -__all__ = ['CancellableAction', 'create_task', 'create_task', 'capture_exceptions'] +__all__ = ['CancellableAction', 'capture_exceptions', 'create_task', 'create_task'] class InvalidFutureError(Exception): diff --git a/src/plumpy/message.py b/src/plumpy/message.py index b5e1348a..c63748f3 100644 --- a/src/plumpy/message.py +++ b/src/plumpy/message.py @@ -28,6 +28,7 @@ INTENT_KEY = 'intent' MESSAGE_KEY = 'message' +FORCE_KILL_KEY = 'force_kill' class Intent: @@ -62,6 +63,45 @@ class Intent: LOGGER = logging.getLogger(__name__) +MessageType = Dict[str, Any] + + +class MessageBuilder: + """MessageBuilder will construct different messages that can passing over communicator.""" + + @classmethod + def play(cls, text: str | None = None) -> MessageType: + """The play message send over communicator.""" + return { + INTENT_KEY: Intent.PLAY, + MESSAGE_KEY: text, + } + + @classmethod + def pause(cls, text: str | None = None) -> MessageType: + """The pause message send over communicator.""" + return { + INTENT_KEY: Intent.PAUSE, + MESSAGE_KEY: text, + } + + @classmethod + def kill(cls, text: str | None = None, force_kill: bool = False) -> MessageType: + """The kill message send over communicator.""" + return { + INTENT_KEY: Intent.KILL, + MESSAGE_KEY: text, + FORCE_KILL_KEY: force_kill, + } + + @classmethod + def status(cls, text: str | None = None) -> MessageType: + """The status message send over communicator.""" + return { + INTENT_KEY: Intent.STATUS, + MESSAGE_KEY: text, + } + def create_launch_body( process_class: str, diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 931dbc5e..723292bf 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -10,7 +10,7 @@ import yaml from yaml.loader import Loader -from plumpy.process_comms import MessageBuilder, MessageType +from plumpy.message import MessageBuilder, MessageType try: import tblib diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index f12874ed..ef1f1f58 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -39,16 +39,15 @@ except ModuleNotFoundError: from contextvars import ContextVar -import kiwipy import yaml from . import events, exceptions, message, persistence, ports, process_states, utils -from .futures import capture_exceptions, CancellableAction from .base import state_machine from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper -from .process_comms import MESSAGE_TEXT_KEY, MessageBuilder, MessageType +from .futures import CancellableAction, capture_exceptions +from .message import MESSAGE_TEXT_KEY, MessageBuilder, MessageType from .process_listener import ProcessListener from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected @@ -325,9 +324,9 @@ def init(self) -> None: try: # filter out state change broadcasts - # TODO: pattern filter should be moved to add_broadcast_subscriber. - subscriber = kiwipy.BroadcastFilter(self.broadcast_receive, subject=re.compile(r'^(?!state_changed).*')) - identifier = self._communicator.add_broadcast_subscriber(subscriber, identifier=str(self.pid)) + identifier = self._communicator.add_broadcast_subscriber( + self.broadcast_receive, subject_filter=re.compile(r'^(?!state_changed).*'), identifier=str(self.pid) + ) self.add_cleanup(functools.partial(self._communicator.remove_broadcast_subscriber, identifier)) except concurrent.futures.TimeoutError: self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid) @@ -741,6 +740,7 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: call_with_super_check(self.on_killed) if self._communicator and isinstance(self.state, enum.Enum): + # FIXME: move all to `coordinator.broadcast()` call and in rmq implement coordinator from plumpy.rmq.exceptions import CommunicatorChannelInvalidStateError, CommunicatorConnectionClosed from_label = cast(enum.Enum, from_state.LABEL).value if from_state is not None else None diff --git a/src/plumpy/rmq/communications.py b/src/plumpy/rmq/communications.py index b66e9694..5e526b23 100644 --- a/src/plumpy/rmq/communications.py +++ b/src/plumpy/rmq/communications.py @@ -130,10 +130,10 @@ def remove_task_subscriber(self, identifier: 'ID_TYPE') -> None: return self._communicator.remove_task_subscriber(identifier) def add_broadcast_subscriber( - self, subscriber: 'BroadcastSubscriber', identifier: Optional['ID_TYPE'] = None + self, subscriber: 'BroadcastSubscriber', subject_filter=None, identifier: Optional['ID_TYPE'] = None ) -> 'ID_TYPE': converted = convert_to_comm(subscriber, self._loop) - return self._communicator.add_broadcast_subscriber(converted, identifier) + return self._communicator.add_broadcast_subscriber(converted, subject_filter, identifier) def remove_broadcast_subscriber(self, identifier: 'ID_TYPE') -> None: return self._communicator.remove_broadcast_subscriber(identifier) diff --git a/src/plumpy/rmq/process_comms.py b/src/plumpy/rmq/process_comms.py index 010fd67d..f7903d6e 100644 --- a/src/plumpy/rmq/process_comms.py +++ b/src/plumpy/rmq/process_comms.py @@ -4,29 +4,22 @@ from __future__ import annotations import asyncio -import logging -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast +from typing import Any, Dict, Optional, Sequence, Union import kiwipy +from plumpy import loaders from plumpy.message import ( - MESSAGE_KEY, - PAUSE_MSG, - PLAY_MSG, - STATUS_MSG, - KILL_MSG, Intent, + MessageBuilder, + MessageType, create_continue_body, create_create_body, create_launch_body, ) - -from plumpy import loaders from plumpy.utils import PID_TYPE __all__ = [ - 'MessageBuilder', - 'ProcessLauncher', 'RemoteProcessController', 'RemoteProcessThreadController', ] @@ -34,159 +27,6 @@ ProcessResult = Any ProcessStatus = Any -INTENT_KEY = 'intent' -MESSAGE_TEXT_KEY = 'message' -FORCE_KILL_KEY = 'force_kill' - - -class Intent: - """Intent constants for a process message""" - - PLAY: str = 'play' - PAUSE: str = 'pause' - KILL: str = 'kill' - STATUS: str = 'status' - - -MessageType = Dict[str, Any] - - -class MessageBuilder: - """MessageBuilder will construct different messages that can passing over communicator.""" - - @classmethod - def play(cls, text: str | None = None) -> MessageType: - """The play message send over communicator.""" - return { - INTENT_KEY: Intent.PLAY, - MESSAGE_TEXT_KEY: text, - } - - @classmethod - def pause(cls, text: str | None = None) -> MessageType: - """The pause message send over communicator.""" - return { - INTENT_KEY: Intent.PAUSE, - MESSAGE_TEXT_KEY: text, - } - - @classmethod - def kill(cls, text: str | None = None, force_kill: bool = False) -> MessageType: - """The kill message send over communicator.""" - return { - INTENT_KEY: Intent.KILL, - MESSAGE_TEXT_KEY: text, - FORCE_KILL_KEY: force_kill, - } - - @classmethod - def status(cls, text: str | None = None) -> MessageType: - """The status message send over communicator.""" - return { - INTENT_KEY: Intent.STATUS, - MESSAGE_TEXT_KEY: text, - } - - -TASK_KEY = 'task' -TASK_ARGS = 'args' -PERSIST_KEY = 'persist' -# Launch -PROCESS_CLASS_KEY = 'process_class' -ARGS_KEY = 'init_args' -KWARGS_KEY = 'init_kwargs' -NOWAIT_KEY = 'nowait' -# Continue -PID_KEY = 'pid' -TAG_KEY = 'tag' -# Task types -LAUNCH_TASK = 'launch' -CONTINUE_TASK = 'continue' -CREATE_TASK = 'create' - -LOGGER = logging.getLogger(__name__) - - -def create_launch_body( - process_class: str, - init_args: Optional[Sequence[Any]] = None, - init_kwargs: Optional[Dict[str, Any]] = None, - persist: bool = False, - loader: Optional[loaders.ObjectLoader] = None, - nowait: bool = True, -) -> Dict[str, Any]: - """ - Create a message body for the launch action - - :param process_class: the class of the process to launch - :param init_args: any initialisation positional arguments - :param init_kwargs: any initialisation keyword arguments - :param persist: persist this process if True, otherwise don't - :param loader: the loader to use to load the persisted process - :param nowait: wait for the process to finish before completing the task, otherwise just return the PID - :return: a dictionary with the body of the message to launch the process - :rtype: dict - """ - if loader is None: - loader = loaders.get_object_loader() - - msg_body = { - TASK_KEY: LAUNCH_TASK, - TASK_ARGS: { - PROCESS_CLASS_KEY: loader.identify_object(process_class), - PERSIST_KEY: persist, - NOWAIT_KEY: nowait, - ARGS_KEY: init_args, - KWARGS_KEY: init_kwargs, - }, - } - return msg_body - - -def create_continue_body(pid: 'PID_TYPE', tag: Optional[str] = None, nowait: bool = False) -> Dict[str, Any]: - """ - Create a message body to continue an existing process - :param pid: the pid of the existing process - :param tag: the optional persistence tag - :param nowait: wait for the process to finish before completing the task, otherwise just return the PID - :return: a dictionary with the body of the message to continue the process - - """ - msg_body = {TASK_KEY: CONTINUE_TASK, TASK_ARGS: {PID_KEY: pid, NOWAIT_KEY: nowait, TAG_KEY: tag}} - return msg_body - - -def create_create_body( - process_class: str, - init_args: Optional[Sequence[Any]] = None, - init_kwargs: Optional[Dict[str, Any]] = None, - persist: bool = False, - loader: Optional[loaders.ObjectLoader] = None, -) -> Dict[str, Any]: - """ - Create a message body to create a new process - :param process_class: the class of the process to launch - :param init_args: any initialisation positional arguments - :param init_kwargs: any initialisation keyword arguments - :param persist: persist this process if True, otherwise don't - :param loader: the loader to use to load the persisted process - :return: a dictionary with the body of the message to launch the process - - """ - if loader is None: - loader = loaders.get_object_loader() - - msg_body = { - TASK_KEY: CREATE_TASK, - TASK_ARGS: { - PROCESS_CLASS_KEY: loader.identify_object(process_class), - PERSIST_KEY: persist, - ARGS_KEY: init_args, - KWARGS_KEY: init_kwargs, - }, - } - return msg_body - class RemoteProcessController: """ diff --git a/tests/rmq/test_communications.py b/tests/rmq/test_communications.py index 63813bdc..00b7f1c6 100644 --- a/tests/rmq/test_communications.py +++ b/tests/rmq/test_communications.py @@ -56,7 +56,7 @@ def test_add_broadcast_subscriber(loop_communicator, subscriber): assert loop_communicator.add_broadcast_subscriber(subscriber) is not None identifier = 'identifier' - assert loop_communicator.add_broadcast_subscriber(subscriber, identifier) == identifier + assert loop_communicator.add_broadcast_subscriber(subscriber, identifier=identifier) == identifier def test_remove_broadcast_subscriber(loop_communicator, subscriber): diff --git a/tests/rmq/test_communicator.py b/tests/rmq/test_communicator.py index ad067c25..e9d20db5 100644 --- a/tests/rmq/test_communicator.py +++ b/tests/rmq/test_communicator.py @@ -7,6 +7,7 @@ import tempfile import uuid +from kiwipy.rmq.communicator import kiwipy import pytest import shortuuid import yaml @@ -84,7 +85,7 @@ def get_broadcast(_comm, body, sender, subject, correlation_id): assert result == BROADCAST @pytest.mark.asyncio - async def test_broadcast_filter(self, loop_communicator): + async def test_broadcast_filter(self, loop_communicator: kiwipy.Communicator): broadcast_future = asyncio.Future() def ignore_broadcast(_comm, body, sender, subject, correlation_id): @@ -93,7 +94,7 @@ def ignore_broadcast(_comm, body, sender, subject, correlation_id): def get_broadcast(_comm, body, sender, subject, correlation_id): broadcast_future.set_result(True) - loop_communicator.add_broadcast_subscriber(BroadcastFilter(ignore_broadcast, subject='other')) + loop_communicator.add_broadcast_subscriber(ignore_broadcast, subject_filter='other') loop_communicator.add_broadcast_subscriber(get_broadcast) loop_communicator.broadcast_send( **{'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420} diff --git a/tests/test_processes.py b/tests/test_processes.py index 3d6b4394..0b38287d 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -12,7 +12,7 @@ import plumpy from plumpy import BundleKeys, Process, ProcessState -from plumpy.process_comms import MESSAGE_TEXT_KEY, MessageBuilder +from plumpy.message import MessageBuilder from plumpy.utils import AttributesFrozendict from tests import utils @@ -1066,6 +1066,7 @@ def test_paused(self): self.assertSetEqual(events_tester.called, events_tester.expected_events) def test_broadcast(self): + # FIXME: here I need a mock test communicator = kiwipy.LocalCommunicator() messages = [] diff --git a/tests/utils.py b/tests/utils.py index 05290990..123d6e72 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -8,7 +8,7 @@ import plumpy from plumpy import persistence, process_states, processes, utils -from plumpy.process_comms import MessageBuilder +from plumpy.message import MessageBuilder Snapshot = collections.namedtuple('Snapshot', ['state', 'bundle', 'outputs']) From 9d6655e98e436945ec285a86675799ea29370d87 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Tue, 17 Dec 2024 12:17:10 +0100 Subject: [PATCH 08/22] Interface change from communicator -> coordinator --- src/plumpy/coordinator.py | 13 +++++++++++++ src/plumpy/message.py | 4 ++-- src/plumpy/processes.py | 32 ++++++++++++++++---------------- src/plumpy/workchains.py | 6 +++--- tests/rmq/test_process_comms.py | 24 ++++++++++++------------ tests/test_processes.py | 2 +- 6 files changed, 47 insertions(+), 34 deletions(-) diff --git a/src/plumpy/coordinator.py b/src/plumpy/coordinator.py index 1daaf1f8..cd66a883 100644 --- a/src/plumpy/coordinator.py +++ b/src/plumpy/coordinator.py @@ -17,3 +17,16 @@ def remove_rpc_subscriber(self, identifier): ... def remove_broadcast_subscriber(self, identifier): ... def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: ... + +class Coordinator(Protocol): + def add_rpc_subscriber(self, subscriber: RpcSubscriber, identifier=None) -> Any: ... + + def add_broadcast_subscriber( + self, subscriber: BroadcastSubscriber, subject_filter: str | Pattern[str] | None = None, identifier=None + ) -> Any: ... + + def remove_rpc_subscriber(self, identifier): ... + + def remove_broadcast_subscriber(self, identifier): ... + + def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: ... diff --git a/src/plumpy/message.py b/src/plumpy/message.py index c63748f3..2d52b048 100644 --- a/src/plumpy/message.py +++ b/src/plumpy/message.py @@ -5,7 +5,7 @@ import logging from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast -from plumpy.coordinator import Communicator +from plumpy.coordinator import Coordinator from plumpy.exceptions import PersistenceError, TaskRejectedError from plumpy.exceptions import PersistenceError, TaskRejectedError @@ -226,7 +226,7 @@ def __init__( else: self._loader = loaders.get_object_loader() - async def __call__(self, communicator: Communicator, task: Dict[str, Any]) -> Union[PID_TYPE, Any]: + async def __call__(self, coordinator: Coordinator, task: Dict[str, Any]) -> Union[PID_TYPE, Any]: """ Receive a task. :param task: The task message diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index ef1f1f58..8753c20b 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -32,7 +32,7 @@ cast, ) -from plumpy.coordinator import Communicator +from plumpy.coordinator import Coordinator try: from aiocontextvars import ContextVar @@ -266,7 +266,7 @@ def __init__( pid: Optional[PID_TYPE] = None, logger: Optional[logging.Logger] = None, loop: Optional[asyncio.AbstractEventLoop] = None, - communicator: Optional[Communicator] = None, + coordinator: Optional[Coordinator] = None, ) -> None: """ The signature of the constructor should not be changed by subclassing processes. @@ -305,7 +305,7 @@ def __init__( self._future = persistence.SavableFuture(loop=self._loop) self._event_helper = EventHelper(ProcessListener) self._logger = logger - self._communicator = communicator + self._coordinator = coordinator @super_check def init(self) -> None: @@ -315,19 +315,19 @@ def init(self) -> None: """ self._cleanups = [] # a list of functions to be ran on terminated - if self._communicator is not None: + if self._coordinator is not None: try: - identifier = self._communicator.add_rpc_subscriber(self.message_receive, identifier=str(self.pid)) - self.add_cleanup(functools.partial(self._communicator.remove_rpc_subscriber, identifier)) + identifier = self._coordinator.add_rpc_subscriber(self.message_receive, identifier=str(self.pid)) + self.add_cleanup(functools.partial(self._coordinator.remove_rpc_subscriber, identifier)) except concurrent.futures.TimeoutError: self.logger.exception('Process<%s>: failed to register as an RPC subscriber', self.pid) try: # filter out state change broadcasts - identifier = self._communicator.add_broadcast_subscriber( + identifier = self._coordinator.add_broadcast_subscriber( self.broadcast_receive, subject_filter=re.compile(r'^(?!state_changed).*'), identifier=str(self.pid) ) - self.add_cleanup(functools.partial(self._communicator.remove_broadcast_subscriber, identifier)) + self.add_cleanup(functools.partial(self._coordinator.remove_broadcast_subscriber, identifier)) except concurrent.futures.TimeoutError: self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid) @@ -448,7 +448,7 @@ def launch( pid=pid, logger=logger, loop=self.loop, - communicator=self._communicator, + coordinator=self._coordinator, ) self.loop.create_task(process.step_until_terminated()) return process @@ -644,7 +644,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self._future = persistence.SavableFuture() self._event_helper = EventHelper(ProcessListener) self._logger = None - self._communicator = None + self._coordinator = None if 'loop' in load_context: self._loop = load_context.loop @@ -653,8 +653,8 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self._state: process_states.State = self.recreate_state(saved_state['_state']) - if 'communicator' in load_context: - self._communicator = load_context.communicator + if 'coordinator' in load_context: + self._coordinator = load_context.coordinator if 'logger' in load_context: self._logger = load_context.logger @@ -739,7 +739,7 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: elif state_label == process_states.ProcessState.KILLED: call_with_super_check(self.on_killed) - if self._communicator and isinstance(self.state, enum.Enum): + if self._coordinator and isinstance(self.state, enum.Enum): # FIXME: move all to `coordinator.broadcast()` call and in rmq implement coordinator from plumpy.rmq.exceptions import CommunicatorChannelInvalidStateError, CommunicatorConnectionClosed @@ -747,7 +747,7 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: subject = f'state_changed.{from_label}.{self.state.value}' self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject) try: - self._communicator.broadcast_send(body=None, sender=self.pid, subject=subject) + self._coordinator.broadcast_send(body=None, sender=self.pid, subject=subject) except (CommunicatorConnectionClosed, CommunicatorChannelInvalidStateError): message = 'Process<%s>: no connection available to broadcast state change from %s to %s' self.logger.warning(message, self.pid, from_label, self.state.value) @@ -937,7 +937,7 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non # region Communication - def message_receive(self, _comm: Communicator, msg: MessageType) -> Any: + def message_receive(self, _comm: Coordinator, msg: MessageType) -> Any: """ Coroutine called when the process receives a message from the communicator @@ -969,7 +969,7 @@ def message_receive(self, _comm: Communicator, msg: MessageType) -> Any: raise RuntimeError('Unknown intent') def broadcast_receive( - self, _comm: kiwipy.Communicator, msg: MessageType, sender: Any, subject: Any, correlation_id: Any + self, _comm: Coordinator, msg: MessageType, sender: Any, subject: Any, correlation_id: Any ) -> Optional[concurrent.futures.Future]: """ Coroutine called when the process receives a message from the communicator diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 7e67253f..5df20bf4 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -23,7 +23,7 @@ cast, ) -from plumpy.coordinator import Communicator +from plumpy.coordinator import Coordinator from . import lang, mixins, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE @@ -128,9 +128,9 @@ def __init__( pid: Optional[PID_TYPE] = None, logger: Optional[logging.Logger] = None, loop: Optional[asyncio.AbstractEventLoop] = None, - communicator: Optional[Communicator] = None, + coordinator: Optional[Coordinator] = None, ) -> None: - super().__init__(inputs=inputs, pid=pid, logger=logger, loop=loop, communicator=communicator) + super().__init__(inputs=inputs, pid=pid, logger=logger, loop=loop, coordinator=coordinator) self._stepper: Optional[Stepper] = None self._awaitables: Dict[Union[asyncio.Future, processes.Process], str] = {} diff --git a/tests/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py index 454c0787..4d9bca29 100644 --- a/tests/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -45,7 +45,7 @@ def sync_controller(thread_communicator: rmq.RmqThreadCommunicator): class TestRemoteProcessController: @pytest.mark.asyncio async def test_pause(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) # Send a pause message @@ -57,7 +57,7 @@ async def test_pause(self, thread_communicator, async_controller): @pytest.mark.asyncio async def test_play(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) assert proc.pause() @@ -75,7 +75,7 @@ async def test_play(self, thread_communicator, async_controller): @pytest.mark.asyncio async def test_kill(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Run the process in the event loop asyncio.ensure_future(proc.step_until_terminated()) @@ -88,7 +88,7 @@ async def test_kill(self, thread_communicator, async_controller): @pytest.mark.asyncio async def test_status(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) @@ -108,7 +108,7 @@ def on_broadcast_receive(**msg): thread_communicator.add_broadcast_subscriber(on_broadcast_receive) - proc = utils.DummyProcess(communicator=thread_communicator) + proc = utils.DummyProcess(coordinator=thread_communicator) proc.execute() expected_subjects = [] @@ -123,7 +123,7 @@ def on_broadcast_receive(**msg): class TestRemoteProcessThreadController: @pytest.mark.asyncio async def test_pause(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Send a pause message pause_future = sync_controller.pause_process(proc.pid) @@ -140,7 +140,7 @@ async def test_pause_all(self, thread_communicator, sync_controller): """Test pausing all processes on a communicator""" procs = [] for _ in range(10): - procs.append(utils.WaitForSignalProcess(communicator=thread_communicator)) + procs.append(utils.WaitForSignalProcess(coordinator=thread_communicator)) sync_controller.pause_all("Slow yo' roll") # Wait until they are all paused @@ -151,7 +151,7 @@ async def test_play_all(self, thread_communicator, sync_controller): """Test pausing all processes on a communicator""" procs = [] for _ in range(10): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) procs.append(proc) proc.pause('hold tight') @@ -162,7 +162,7 @@ async def test_play_all(self, thread_communicator, sync_controller): @pytest.mark.asyncio async def test_play(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) assert proc.pause() # Send a play message @@ -176,7 +176,7 @@ async def test_play(self, thread_communicator, sync_controller): @pytest.mark.asyncio async def test_kill(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Send a kill message kill_future = sync_controller.kill_process(proc.pid) @@ -193,7 +193,7 @@ async def test_kill_all(self, thread_communicator, sync_controller): """Test pausing all processes on a communicator""" procs = [] for _ in range(10): - procs.append(utils.WaitForSignalProcess(communicator=thread_communicator)) + procs.append(utils.WaitForSignalProcess(coordinator=thread_communicator)) sync_controller.kill_all(msg_text='bang bang, I shot you down') await utils.wait_util(lambda: all([proc.killed() for proc in procs])) @@ -201,7 +201,7 @@ async def test_kill_all(self, thread_communicator, sync_controller): @pytest.mark.asyncio async def test_status(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) diff --git a/tests/test_processes.py b/tests/test_processes.py index 0b38287d..99e28de6 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -1075,7 +1075,7 @@ def on_broadcast_receive(_comm, body, sender, subject, correlation_id): messages.append({'body': body, 'subject': subject, 'sender': sender, 'correlation_id': correlation_id}) communicator.add_broadcast_subscriber(on_broadcast_receive) - proc = utils.DummyProcess(communicator=communicator) + proc = utils.DummyProcess(coordinator=communicator) proc.execute() expected_subjects = [] From f8cc8ed290e80b4cd8d8775ba020ce6aa69499ef Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Tue, 17 Dec 2024 12:30:48 +0100 Subject: [PATCH 09/22] Remove unnecessary task_send ab from RemoteProcessControl interface --- src/plumpy/rmq/process_comms.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/src/plumpy/rmq/process_comms.py b/src/plumpy/rmq/process_comms.py index f7903d6e..38355b03 100644 --- a/src/plumpy/rmq/process_comms.py +++ b/src/plumpy/rmq/process_comms.py @@ -273,7 +273,7 @@ def continue_process( self, pid: 'PID_TYPE', tag: Optional[str] = None, nowait: bool = False, no_reply: bool = False ) -> Union[None, PID_TYPE, ProcessResult]: message = create_continue_body(pid=pid, tag=tag, nowait=nowait) - return self.task_send(message, no_reply=no_reply) + return self._communicator.task_send(message, no_reply=no_reply) def launch_process( self, @@ -298,7 +298,7 @@ def launch_process( :return: the pid of the created process or the outputs (if nowait=False) """ message = create_launch_body(process_class, init_args, init_kwargs, persist, loader, nowait) - return self.task_send(message, no_reply=no_reply) + return self._communicator.task_send(message, no_reply=no_reply) def execute_process( self, @@ -335,13 +335,3 @@ def on_created(_: Any) -> None: create_future.add_done_callback(on_created) return execute_future - - def task_send(self, message: Any, no_reply: bool = False) -> Optional[Any]: - """ - Send a task to be performed using the communicator - - :param message: the task message - :param no_reply: if True, this call will be fire-and-forget, i.e. no return value - :return: the response from the remote side (if no_reply=False) - """ - return self._communicator.task_send(message, no_reply=no_reply) From 34d38427e27bdf064f2b839efce1a1e04e12629b Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Tue, 17 Dec 2024 13:09:52 +0100 Subject: [PATCH 10/22] Interface for ProcessController --- src/plumpy/__init__.py | 6 +- src/plumpy/controller.py | 113 ++++++++++++++++++++++++++++++++ src/plumpy/coordinator.py | 44 ++++++++++--- src/plumpy/rmq/process_comms.py | 61 ++++++++--------- 4 files changed, 185 insertions(+), 39 deletions(-) create mode 100644 src/plumpy/controller.py diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index 8f62edb6..a63be7d1 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -19,6 +19,10 @@ from .workchains import * from .rmq import * +# interfaces +from .controller import ProcessController +from .coordinator import Coordinator + __all__ = ( events.__all__ + exceptions.__all__ @@ -33,7 +37,7 @@ + loaders.__all__ + ports.__all__ + process_states.__all__ -) +) + ['ProcessController', 'Coordinator'] # Do this se we don't get the "No handlers could be found..." warnings that will be produced diff --git a/src/plumpy/controller.py b/src/plumpy/controller.py new file mode 100644 index 00000000..5a411fd1 --- /dev/null +++ b/src/plumpy/controller.py @@ -0,0 +1,113 @@ +from collections.abc import Sequence +from typing import Any, Protocol + +from plumpy import loaders +from plumpy.message import MessageType +from plumpy.utils import PID_TYPE + +ProcessResult = Any +ProcessStatus = Any + + +class ProcessController(Protocol): + """ + Control processes using coroutines that will send messages and wait + (in a non-blocking way) for their response + """ + + def get_status(self, pid: 'PID_TYPE') -> ProcessStatus: + """ + Get the status of a process with the given PID + :param pid: the process id + :return: the status response from the process + """ + ... + + def pause_process(self, pid: 'PID_TYPE', msg: Any | None = None) -> ProcessResult: + """ + Pause the process + + :param pid: the pid of the process to pause + :param msg: optional pause message + :return: True if paused, False otherwise + """ + ... + + def play_process(self, pid: 'PID_TYPE') -> ProcessResult: + """ + Play the process + + :param pid: the pid of the process to play + :return: True if played, False otherwise + """ + ... + + def kill_process(self, pid: 'PID_TYPE', msg: MessageType | None = None) -> ProcessResult: + """ + Kill the process + + :param pid: the pid of the process to kill + :param msg: optional kill message + :return: True if killed, False otherwise + """ + ... + + def continue_process( + self, pid: 'PID_TYPE', tag: str|None = None, nowait: bool = False, no_reply: bool = False + ) -> ProcessResult | None: + """ + Continue the process + + :param _communicator: the communicator + :param pid: the pid of the process to continue + :param tag: the checkpoint tag to continue from + """ + ... + + async def launch_process( + self, + process_class: str, + init_args: Sequence[Any] | None = None, + init_kwargs: dict[str, Any] | None = None, + persist: bool = False, + loader: loaders.ObjectLoader | None = None, + nowait: bool = False, + no_reply: bool = False, + ) -> ProcessResult: + """ + Launch a process given the class and constructor arguments + + :param process_class: the class of the process to launch + :param init_args: the constructor positional arguments + :param init_kwargs: the constructor keyword arguments + :param persist: should the process be persisted + :param loader: the classloader to use + :param nowait: if True, don't wait for the process to send a response, just return the pid + :param no_reply: if True, this call will be fire-and-forget, i.e. no return value + :return: the result of launching the process + """ + ... + + async def execute_process( + self, + process_class: str, + init_args: Sequence[Any] | None = None, + init_kwargs: dict[str, Any] | None = None, + loader: loaders.ObjectLoader | None = None, + nowait: bool = False, + no_reply: bool = False, + ) -> ProcessResult: + """ + Execute a process. This call will first send a create task and then a continue task over + the communicator. This means that if communicator messages are durable then the process + will run until the end even if this interpreter instance ceases to exist. + + :param process_class: the process class to execute + :param init_args: the positional arguments to the class constructor + :param init_kwargs: the keyword arguments to the class constructor + :param loader: the class loader to use + :param nowait: if True, don't wait for the process to send a response + :param no_reply: if True, this call will be fire-and-forget, i.e. no return value + :return: the result of executing the process + """ + ... diff --git a/src/plumpy/coordinator.py b/src/plumpy/coordinator.py index cd66a883..c229a6c4 100644 --- a/src/plumpy/coordinator.py +++ b/src/plumpy/coordinator.py @@ -1,15 +1,25 @@ # -*- coding: utf-8 -*- -from typing import Any, Callable, Pattern, Protocol +import concurrent.futures +from typing import TYPE_CHECKING, Any, Callable, Hashable, Pattern, Protocol -RpcSubscriber = Callable[['Communicator', Any], Any] -BroadcastSubscriber = Callable[['Communicator', Any, Any, Any, Any], Any] + +if TYPE_CHECKING: + # identifiers for subscribers + ID_TYPE = Hashable + Subscriber = Callable[..., Any] + # RPC subscriber params: communicator, msg + RpcSubscriber = Callable[['Coordinator', Any], Any] + # Task subscriber params: communicator, task + TaskSubscriber = Callable[['Coordinator', Any], Any] + # Broadcast subscribers params: communicator, body, sender, subject, correlation id + BroadcastSubscriber = Callable[['Coordinator', Any, Any, Any, ID_TYPE], Any] class Communicator(Protocol): - def add_rpc_subscriber(self, subscriber: RpcSubscriber, identifier=None) -> Any: ... + def add_rpc_subscriber(self, subscriber: 'RpcSubscriber', identifier: 'ID_TYPE | None' = None) -> Any: ... def add_broadcast_subscriber( - self, subscriber: BroadcastSubscriber, subject_filter: str | Pattern[str] | None = None, identifier=None + self, subscriber: 'BroadcastSubscriber', subject_filter: str | Pattern[str] | None = None, identifier=None ) -> Any: ... def remove_rpc_subscriber(self, identifier): ... @@ -18,15 +28,33 @@ def remove_broadcast_subscriber(self, identifier): ... def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: ... + class Coordinator(Protocol): - def add_rpc_subscriber(self, subscriber: RpcSubscriber, identifier=None) -> Any: ... + def add_rpc_subscriber(self, subscriber: 'RpcSubscriber', identifier=None) -> Any: ... def add_broadcast_subscriber( - self, subscriber: BroadcastSubscriber, subject_filter: str | Pattern[str] | None = None, identifier=None + self, + subscriber: 'BroadcastSubscriber', + subject_filter: str | Pattern[str] | None = None, + identifier: 'ID_TYPE | None' = None, ) -> Any: ... + def add_task_subscriber(self, subscriber: 'TaskSubscriber', identifier: 'ID_TYPE | None' = None) -> 'ID_TYPE': ... + def remove_rpc_subscriber(self, identifier): ... def remove_broadcast_subscriber(self, identifier): ... - def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: ... + def remove_task_subscriber(self, identifier: 'ID_TYPE') -> None: ... + + def rpc_send(self, recipient_id: Hashable, msg: Any) -> Any: ... + + def broadcast_send( + self, + body: Any | None, + sender: str | None = None, + subject: str | None = None, + correlation_id: 'ID_TYPE | None' = None, + ) -> Any: ... + + def task_send(self, task: Any, no_reply: bool = False) -> Any: ... diff --git a/src/plumpy/rmq/process_comms.py b/src/plumpy/rmq/process_comms.py index 38355b03..91484332 100644 --- a/src/plumpy/rmq/process_comms.py +++ b/src/plumpy/rmq/process_comms.py @@ -9,6 +9,7 @@ import kiwipy from plumpy import loaders +from plumpy.coordinator import Coordinator from plumpy.message import ( Intent, MessageBuilder, @@ -34,8 +35,8 @@ class RemoteProcessController: (in a non-blocking way) for their response """ - def __init__(self, communicator: kiwipy.Communicator) -> None: - self._communicator = communicator + def __init__(self, coordinator: Coordinator) -> None: + self._coordinator = coordinator async def get_status(self, pid: 'PID_TYPE') -> 'ProcessStatus': """ @@ -43,7 +44,7 @@ async def get_status(self, pid: 'PID_TYPE') -> 'ProcessStatus': :param pid: the process id :return: the status response from the process """ - future = self._communicator.rpc_send(pid, MessageBuilder.status()) + future = self._coordinator.rpc_send(pid, MessageBuilder.status()) result = await asyncio.wrap_future(future) return result @@ -57,8 +58,8 @@ async def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) - """ msg = MessageBuilder.pause(text=msg_text) - pause_future = self._communicator.rpc_send(pid, msg) - # rpc_send return a thread future from communicator + pause_future = self._coordinator.rpc_send(pid, msg) + # rpc_send return a thread future from coordinator future = await asyncio.wrap_future(pause_future) # future is just returned from rpc call which return a kiwipy future result = await asyncio.wrap_future(future) @@ -71,7 +72,7 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult': :param pid: the pid of the process to play :return: True if played, False otherwise """ - play_future = self._communicator.rpc_send(pid, MessageBuilder.play()) + play_future = self._coordinator.rpc_send(pid, MessageBuilder.play()) future = await asyncio.wrap_future(play_future) result = await asyncio.wrap_future(future) return result @@ -87,7 +88,7 @@ async def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> msg = MessageBuilder.kill(text=msg_text) # Wait for the communication to go through - kill_future = self._communicator.rpc_send(pid, msg) + kill_future = self._coordinator.rpc_send(pid, msg) future = await asyncio.wrap_future(kill_future) # Now wait for the kill to be enacted result = await asyncio.wrap_future(future) @@ -99,13 +100,13 @@ async def continue_process( """ Continue the process - :param _communicator: the communicator + :param _coordinator: the coordinator :param pid: the pid of the process to continue :param tag: the checkpoint tag to continue from """ message = create_continue_body(pid=pid, tag=tag, nowait=nowait) # Wait for the communication to go through - continue_future = self._communicator.task_send(message, no_reply=no_reply) + continue_future = self._coordinator.task_send(message, no_reply=no_reply) future = await asyncio.wrap_future(continue_future) if no_reply: @@ -139,7 +140,7 @@ async def launch_process( """ message = create_launch_body(process_class, init_args, init_kwargs, persist, loader, nowait) - launch_future = self._communicator.task_send(message, no_reply=no_reply) + launch_future = self._coordinator.task_send(message, no_reply=no_reply) future = await asyncio.wrap_future(launch_future) if no_reply: @@ -159,7 +160,7 @@ async def execute_process( ) -> 'ProcessResult': """ Execute a process. This call will first send a create task and then a continue task over - the communicator. This means that if communicator messages are durable then the process + the coordinator. This means that if coordinator messages are durable then the process will run until the end even if this interpreter instance ceases to exist. :param process_class: the process class to execute @@ -173,12 +174,12 @@ async def execute_process( message = create_create_body(process_class, init_args, init_kwargs, persist=True, loader=loader) - create_future = self._communicator.task_send(message) + create_future = self._coordinator.task_send(message) future = await asyncio.wrap_future(create_future) pid: 'PID_TYPE' = await asyncio.wrap_future(future) message = create_continue_body(pid, nowait=nowait) - continue_future = self._communicator.task_send(message, no_reply=no_reply) + continue_future = self._coordinator.task_send(message, no_reply=no_reply) future = await asyncio.wrap_future(continue_future) if no_reply: @@ -193,14 +194,14 @@ class RemoteProcessThreadController: A class that can be used to control and launch remote processes """ - def __init__(self, communicator: kiwipy.Communicator): + def __init__(self, coordinator: Coordinator): """ Create a new process controller - :param communicator: the communicator to use + :param coordinator: the coordinator to use """ - self._communicator = communicator + self._coordinator = coordinator def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future: """Get the status of a process with the given PID. @@ -208,7 +209,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future: :param pid: the process id :return: the status response from the process """ - return self._communicator.rpc_send(pid, MessageBuilder.status()) + return self._coordinator.rpc_send(pid, MessageBuilder.status()) def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future: """ @@ -221,16 +222,16 @@ def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwi """ msg = MessageBuilder.pause(text=msg_text) - return self._communicator.rpc_send(pid, msg) + return self._coordinator.rpc_send(pid, msg) def pause_all(self, msg_text: Optional[str]) -> None: """ - Pause all processes that are subscribed to the same communicator + Pause all processes that are subscribed to the same coordinator :param msg: an optional pause message """ msg = MessageBuilder.pause(text=msg_text) - self._communicator.broadcast_send(msg, subject=Intent.PAUSE) + self._coordinator.broadcast_send(msg, subject=Intent.PAUSE) def play_process(self, pid: 'PID_TYPE') -> kiwipy.Future: """ @@ -240,13 +241,13 @@ def play_process(self, pid: 'PID_TYPE') -> kiwipy.Future: :return: a response future from the process to be played """ - return self._communicator.rpc_send(pid, MessageBuilder.play()) + return self._coordinator.rpc_send(pid, MessageBuilder.play()) def play_all(self) -> None: """ - Play all processes that are subscribed to the same communicator + Play all processes that are subscribed to the same coordinator """ - self._communicator.broadcast_send(None, subject=Intent.PLAY) + self._coordinator.broadcast_send(None, subject=Intent.PLAY) def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future: """ @@ -257,23 +258,23 @@ def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwip :return: a response future from the process to be killed """ msg = MessageBuilder.kill(text=msg_text) - return self._communicator.rpc_send(pid, msg) + return self._coordinator.rpc_send(pid, msg) def kill_all(self, msg_text: Optional[str]) -> None: """ - Kill all processes that are subscribed to the same communicator + Kill all processes that are subscribed to the same coordinator :param msg: an optional pause message """ msg = MessageBuilder.kill(msg_text) - self._communicator.broadcast_send(msg, subject=Intent.KILL) + self._coordinator.broadcast_send(msg, subject=Intent.KILL) def continue_process( self, pid: 'PID_TYPE', tag: Optional[str] = None, nowait: bool = False, no_reply: bool = False ) -> Union[None, PID_TYPE, ProcessResult]: message = create_continue_body(pid=pid, tag=tag, nowait=nowait) - return self._communicator.task_send(message, no_reply=no_reply) + return self._coordinator.task_send(message, no_reply=no_reply) def launch_process( self, @@ -298,7 +299,7 @@ def launch_process( :return: the pid of the created process or the outputs (if nowait=False) """ message = create_launch_body(process_class, init_args, init_kwargs, persist, loader, nowait) - return self._communicator.task_send(message, no_reply=no_reply) + return self._coordinator.task_send(message, no_reply=no_reply) def execute_process( self, @@ -311,7 +312,7 @@ def execute_process( ) -> Union[None, PID_TYPE, ProcessResult]: """ Execute a process. This call will first send a create task and then a continue task over - the communicator. This means that if communicator messages are durable then the process + the coordinator. This means that if coordinator messages are durable then the process will run until the end even if this interpreter instance ceases to exist. :param process_class: the process class to execute @@ -325,7 +326,7 @@ def execute_process( message = create_create_body(process_class, init_args, init_kwargs, persist=True, loader=loader) execute_future = kiwipy.Future() - create_future = self._communicator.task_send(message) + create_future = self._coordinator.task_send(message) def on_created(_: Any) -> None: with kiwipy.capture_exceptions(execute_future): From bf99d23e28c9ee5d63ec7fe63076829c5efcb00f Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Tue, 17 Dec 2024 16:30:15 +0100 Subject: [PATCH 11/22] RmqCoordinator example to show how using interface can avoid making change in kiwipy --- docs/source/concepts.md | 2 +- docs/source/tutorial.ipynb | 2 +- src/plumpy/__init__.py | 7 +- src/plumpy/controller.py | 5 +- src/plumpy/coordinator.py | 26 +--- src/plumpy/exceptions.py | 14 ++ src/plumpy/futures.py | 4 +- src/plumpy/message.py | 2 + src/plumpy/processes.py | 18 ++- src/plumpy/rmq/__init__.py | 7 +- src/plumpy/rmq/communications.py | 6 +- src/plumpy/rmq/exceptions.py | 14 -- src/plumpy/rmq/futures.py | 2 + .../{process_comms.py => process_control.py} | 13 +- tests/base/test_statemachine.py | 2 +- tests/rmq/__init__.py | 65 +++++++++ tests/rmq/test_communications.py | 76 ++++++----- tests/rmq/test_communicator.py | 77 ++++++----- ...ocess_comms.py => test_process_control.py} | 73 +++++----- tests/test_processes.py | 14 +- tests/utils.py | 128 ++++++++++++++++++ 21 files changed, 379 insertions(+), 178 deletions(-) delete mode 100644 src/plumpy/rmq/exceptions.py rename src/plumpy/rmq/{process_comms.py => process_control.py} (96%) rename tests/rmq/{test_process_comms.py => test_process_control.py} (71%) diff --git a/docs/source/concepts.md b/docs/source/concepts.md index ba6e8b17..0c39d515 100644 --- a/docs/source/concepts.md +++ b/docs/source/concepts.md @@ -32,7 +32,7 @@ WorkChains support the use of logical constructs such as `If_` and `While_` to c A `Controller` can control processes throughout their lifetime, by sending and receiving messages. It can launch, pause, continue, kill and check status of the process. -The {py:class}`~plumpy.process_comms.RemoteProcessThreadController` can communicate with the process over the thread communicator provided by {{kiwipy}} which can subscribe and send messages over the {{rabbitmq}} message broker. +The {py:class}`~plumpy.rmq.process_control.RemoteProcessThreadController` can communicate with the process over the thread communicator provided by {{kiwipy}} which can subscribe and send messages over the {{rabbitmq}} message broker. The thread communicator runs on a independent thread (event loop) and so will not be blocked by sometimes long waiting times in the process event loop. Using RabbitMQ means that even if the computer is terminated unexpectedly, messages are persisted and can be run once the computer restarts. diff --git a/docs/source/tutorial.ipynb b/docs/source/tutorial.ipynb index b544d38b..ba0dd8ca 100644 --- a/docs/source/tutorial.ipynb +++ b/docs/source/tutorial.ipynb @@ -66,7 +66,7 @@ "The {py:class}`~plumpy.workchains.WorkChain`\n", ": A subclass of `Process` that allows for running a process as a set of discrete steps (also known as instructions), with the ability to save the state of the process after each instruction has completed.\n", "\n", - "The process `Controller` (principally the {py:class}`~plumpy.process_comms.RemoteProcessThreadController`)\n", + "The process `Controller` (principally the {py:class}`~plumpy.rmq.process_control.RemoteProcessThreadController`)\n", ": To control the process or workchain throughout its lifetime." ] }, diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index a63be7d1..cc65ba23 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -4,6 +4,9 @@ import logging +# interfaces +from .controller import ProcessController +from .coordinator import Coordinator from .events import * from .exceptions import * from .futures import * @@ -19,10 +22,6 @@ from .workchains import * from .rmq import * -# interfaces -from .controller import ProcessController -from .coordinator import Coordinator - __all__ = ( events.__all__ + exceptions.__all__ diff --git a/src/plumpy/controller.py b/src/plumpy/controller.py index 5a411fd1..dcf203dc 100644 --- a/src/plumpy/controller.py +++ b/src/plumpy/controller.py @@ -1,3 +1,6 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + from collections.abc import Sequence from typing import Any, Protocol @@ -53,7 +56,7 @@ def kill_process(self, pid: 'PID_TYPE', msg: MessageType | None = None) -> Proce ... def continue_process( - self, pid: 'PID_TYPE', tag: str|None = None, nowait: bool = False, no_reply: bool = False + self, pid: 'PID_TYPE', tag: str | None = None, nowait: bool = False, no_reply: bool = False ) -> ProcessResult | None: """ Continue the process diff --git a/src/plumpy/coordinator.py b/src/plumpy/coordinator.py index c229a6c4..b3dcbec5 100644 --- a/src/plumpy/coordinator.py +++ b/src/plumpy/coordinator.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -import concurrent.futures -from typing import TYPE_CHECKING, Any, Callable, Hashable, Pattern, Protocol +from __future__ import annotations +from typing import TYPE_CHECKING, Any, Callable, Hashable, Pattern, Protocol if TYPE_CHECKING: # identifiers for subscribers @@ -15,22 +15,8 @@ BroadcastSubscriber = Callable[['Coordinator', Any, Any, Any, ID_TYPE], Any] -class Communicator(Protocol): - def add_rpc_subscriber(self, subscriber: 'RpcSubscriber', identifier: 'ID_TYPE | None' = None) -> Any: ... - - def add_broadcast_subscriber( - self, subscriber: 'BroadcastSubscriber', subject_filter: str | Pattern[str] | None = None, identifier=None - ) -> Any: ... - - def remove_rpc_subscriber(self, identifier): ... - - def remove_broadcast_subscriber(self, identifier): ... - - def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: ... - - class Coordinator(Protocol): - def add_rpc_subscriber(self, subscriber: 'RpcSubscriber', identifier=None) -> Any: ... + def add_rpc_subscriber(self, subscriber: 'RpcSubscriber', identifier: 'ID_TYPE | None' = None) -> Any: ... def add_broadcast_subscriber( self, @@ -41,9 +27,9 @@ def add_broadcast_subscriber( def add_task_subscriber(self, subscriber: 'TaskSubscriber', identifier: 'ID_TYPE | None' = None) -> 'ID_TYPE': ... - def remove_rpc_subscriber(self, identifier): ... + def remove_rpc_subscriber(self, identifier: 'ID_TYPE | None') -> None: ... - def remove_broadcast_subscriber(self, identifier): ... + def remove_broadcast_subscriber(self, identifier: 'ID_TYPE | None') -> None: ... def remove_task_subscriber(self, identifier: 'ID_TYPE') -> None: ... @@ -52,7 +38,7 @@ def rpc_send(self, recipient_id: Hashable, msg: Any) -> Any: ... def broadcast_send( self, body: Any | None, - sender: str | None = None, + sender: 'ID_TYPE | None' = None, subject: str | None = None, correlation_id: 'ID_TYPE | None' = None, ) -> Any: ... diff --git a/src/plumpy/exceptions.py b/src/plumpy/exceptions.py index 9dca8fdb..5d05ea4b 100644 --- a/src/plumpy/exceptions.py +++ b/src/plumpy/exceptions.py @@ -3,6 +3,8 @@ __all__ = [ 'ClosedError', + 'CoordinatorConnectionError', + 'CoordinatorTimeoutError', 'InvalidStateError', 'KilledError', 'PersistenceError', @@ -42,3 +44,15 @@ class ClosedError(Exception): class TaskRejectedError(Exception): """A task was rejected by the coordinacor""" + + +class CoordinatorCommunicationError(Exception): + """Generic coordinator communication error""" + + +class CoordinatorConnectionError(ConnectionError): + """Raised when coordinator cannot be connected""" + + +class CoordinatorTimeoutError(TimeoutError): + """Raised when communicate with coordinator timeout""" diff --git a/src/plumpy/futures.py b/src/plumpy/futures.py index 01be3951..f3e8a30b 100644 --- a/src/plumpy/futures.py +++ b/src/plumpy/futures.py @@ -3,6 +3,8 @@ Module containing future related methods and classes """ +from __future__ import annotations + import asyncio import contextlib from typing import Any, Awaitable, Callable, Generator, Optional @@ -18,7 +20,7 @@ class InvalidFutureError(Exception): @contextlib.contextmanager -def capture_exceptions(future: Future[Any], ignore: tuple[type[BaseException], ...] = ()) -> Generator[None, Any, None]: +def capture_exceptions(future, ignore: tuple[type[BaseException], ...] = ()) -> Generator[None, Any, None]: # type: ignore[no-untyped-def] """ Capture any exceptions in the context and set them as the result of the given future diff --git a/src/plumpy/message.py b/src/plumpy/message.py index 2d52b048..58c1c6bd 100644 --- a/src/plumpy/message.py +++ b/src/plumpy/message.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- """Module for process level communication functions and classes""" +from __future__ import annotations + import asyncio import logging from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 8753c20b..c1381471 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -740,20 +740,18 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: call_with_super_check(self.on_killed) if self._coordinator and isinstance(self.state, enum.Enum): - # FIXME: move all to `coordinator.broadcast()` call and in rmq implement coordinator - from plumpy.rmq.exceptions import CommunicatorChannelInvalidStateError, CommunicatorConnectionClosed - from_label = cast(enum.Enum, from_state.LABEL).value if from_state is not None else None subject = f'state_changed.{from_label}.{self.state.value}' self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject) try: self._coordinator.broadcast_send(body=None, sender=self.pid, subject=subject) - except (CommunicatorConnectionClosed, CommunicatorChannelInvalidStateError): - message = 'Process<%s>: no connection available to broadcast state change from %s to %s' - self.logger.warning(message, self.pid, from_label, self.state.value) - except concurrent.futures.TimeoutError: - message = 'Process<%s>: sending broadcast of state change from %s to %s timed out' - self.logger.warning(message, self.pid, from_label, self.state.value) + except exceptions.CoordinatorCommunicationError: + message = f'Process<{self.pid}>: cannot broadcast state change from {from_label} to {self.state.value}' + self.logger.warning(message) + self.logger.debug(message, exc_info=True) + except Exception: + # bubble up for unknown exception + raise def on_exiting(self) -> None: state = self.state @@ -1019,7 +1017,7 @@ def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) :return: a kiwi future that resolves to the outcome of the callback """ - kiwi_future = concurrent.futures.Future() + kiwi_future = concurrent.futures.Future() # type: ignore[var-annotated] async def run_callback() -> None: with capture_exceptions(kiwi_future): diff --git a/src/plumpy/rmq/__init__.py b/src/plumpy/rmq/__init__.py index ad0642ca..c44c5a2e 100644 --- a/src/plumpy/rmq/__init__.py +++ b/src/plumpy/rmq/__init__.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- -from .exceptions import * +# mypy: disable-error-code=name-defined +from .communications import * from .futures import * -from .process_comms import * +from .process_control import * -__all__ = exceptions.__all__ + communications.__all__ + futures.__all__ + process_comms.__all__ +__all__ = communications.__all__ + futures.__all__ + process_control.__all__ diff --git a/src/plumpy/rmq/communications.py b/src/plumpy/rmq/communications.py index 5e526b23..cb0012c9 100644 --- a/src/plumpy/rmq/communications.py +++ b/src/plumpy/rmq/communications.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- """Module for general kiwipy communication methods""" +from __future__ import annotations + import asyncio import functools from typing import TYPE_CHECKING, Any, Callable, Hashable, Optional @@ -130,10 +132,10 @@ def remove_task_subscriber(self, identifier: 'ID_TYPE') -> None: return self._communicator.remove_task_subscriber(identifier) def add_broadcast_subscriber( - self, subscriber: 'BroadcastSubscriber', subject_filter=None, identifier: Optional['ID_TYPE'] = None + self, subscriber: 'BroadcastSubscriber', identifier: Optional['ID_TYPE'] = None ) -> 'ID_TYPE': converted = convert_to_comm(subscriber, self._loop) - return self._communicator.add_broadcast_subscriber(converted, subject_filter, identifier) + return self._communicator.add_broadcast_subscriber(converted, identifier) def remove_broadcast_subscriber(self, identifier: 'ID_TYPE') -> None: return self._communicator.remove_broadcast_subscriber(identifier) diff --git a/src/plumpy/rmq/exceptions.py b/src/plumpy/rmq/exceptions.py deleted file mode 100644 index 02eb3c97..00000000 --- a/src/plumpy/rmq/exceptions.py +++ /dev/null @@ -1,14 +0,0 @@ -# -*- coding: utf-8 -*- -import kiwipy -from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed - -__all__ = [ - 'CommunicatorChannelInvalidStateError', - 'CommunicatorConnectionClosed', -] - -# Alias aio_pika -CommunicatorConnectionClosed = ConnectionClosed -CommunicatorChannelInvalidStateError = ChannelInvalidStateError - -CancelledError = kiwipy.CancelledError diff --git a/src/plumpy/rmq/futures.py b/src/plumpy/rmq/futures.py index 897c8147..73e9e36f 100644 --- a/src/plumpy/rmq/futures.py +++ b/src/plumpy/rmq/futures.py @@ -2,6 +2,8 @@ # mypy: disable-error-code="no-untyped-def, no-untyped-call" """Module containing future related methods and classes""" +from __future__ import annotations + import asyncio import concurrent.futures from typing import Any diff --git a/src/plumpy/rmq/process_comms.py b/src/plumpy/rmq/process_control.py similarity index 96% rename from src/plumpy/rmq/process_comms.py rename to src/plumpy/rmq/process_control.py index 91484332..e9ed3ef8 100644 --- a/src/plumpy/rmq/process_comms.py +++ b/src/plumpy/rmq/process_control.py @@ -4,7 +4,7 @@ from __future__ import annotations import asyncio -from typing import Any, Dict, Optional, Sequence, Union +from typing import Any, Dict, Hashable, Optional, Sequence, Union import kiwipy @@ -270,6 +270,17 @@ def kill_all(self, msg_text: Optional[str]) -> None: self._coordinator.broadcast_send(msg, subject=Intent.KILL) + def notify_all(self, msg: MessageType | None, sender: Hashable | None = None, subject: str | None = None) -> None: + """ + Notify all processes by broadcasting + + :param msg: an optional pause message + """ + if msg is None: + msg = MessageBuilder.kill() + + self._coordinator.broadcast_send(msg, sender=sender, subject=subject) + def continue_process( self, pid: 'PID_TYPE', tag: Optional[str] = None, nowait: bool = False, no_reply: bool = False ) -> Union[None, PID_TYPE, ProcessResult]: diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index 3a1621a2..ddcbb8d9 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -40,7 +40,7 @@ def exit(self): super().exit() self._update_time() - def play(self, track=None): # pylint: disable=no-self-use, unused-argument + def play(self, track=None): return False def _update_time(self): diff --git a/tests/rmq/__init__.py b/tests/rmq/__init__.py index e69de29b..72078829 100644 --- a/tests/rmq/__init__.py +++ b/tests/rmq/__init__.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +import kiwipy +import concurrent.futures + +from plumpy.exceptions import CoordinatorConnectionError + + +class RmqCoordinator: + def __init__(self, comm: kiwipy.Communicator): + self._comm = comm + + # XXX: naming - `add_receiver_rpc` + def add_rpc_subscriber(self, subscriber, identifier=None): + return self._comm.add_rpc_subscriber(subscriber, identifier) + + # XXX: naming - `add_receiver_broadcast` + def add_broadcast_subscriber( + self, + subscriber, + subject_filter=None, + identifier=None, + ): + subscriber = kiwipy.BroadcastFilter(subscriber, subject=subject_filter) + return self._comm.add_broadcast_subscriber(subscriber, identifier) + + # XXX: naming - `add_reciver_task` (can be combined with two above maybe??) + def add_task_subscriber(self, subscriber, identifier=None): + return self._comm.add_task_subscriber(subscriber, identifier) + + def remove_rpc_subscriber(self, identifier): + return self._comm.remove_rpc_subscriber(identifier) + + def remove_broadcast_subscriber(self, identifier): + return self._comm.remove_broadcast_subscriber(identifier) + + def remove_task_subscriber(self, identifier): + return self._comm.remove_task_subscriber(identifier) + + # XXX: naming - `send_to` + def rpc_send(self, recipient_id, msg): + return self._comm.rpc_send(recipient_id, msg) + + # XXX: naming - `broadcast` + def broadcast_send( + self, + body, + sender=None, + subject=None, + correlation_id=None, + ): + from aio_pika.exceptions import ChannelInvalidStateError, AMQPConnectionError + + try: + rsp = self._comm.broadcast_send(body, sender, subject, correlation_id) + except (ChannelInvalidStateError, AMQPConnectionError, concurrent.futures.TimeoutError) as exc: + raise CoordinatorConnectionError from exc + else: + return rsp + + # XXX: naming - `assign_task` (this may able to be combined with send_to) + def task_send(self, task, no_reply=False): + return self._comm.task_send(task, no_reply) + + def close(self): + self._comm.close() diff --git a/tests/rmq/test_communications.py b/tests/rmq/test_communications.py index 00b7f1c6..e45994b2 100644 --- a/tests/rmq/test_communications.py +++ b/tests/rmq/test_communications.py @@ -2,75 +2,81 @@ """Tests for the :mod:`plumpy.rmq.communications` module.""" import pytest -from kiwipy import CommunicatorHelper +import kiwipy from plumpy.rmq.communications import LoopCommunicator +from . import RmqCoordinator -class Subscriber: - """Test class that mocks a subscriber.""" - - def __call__(self): - pass +@pytest.fixture +def _coordinator(): + """Return an instance of `LoopCommunicator`.""" + class _Communicator(kiwipy.CommunicatorHelper): + def task_send(self, task, no_reply=False): + pass -class Communicator(CommunicatorHelper): - def task_send(self, task, no_reply=False): - pass + def rpc_send(self, recipient_id, msg): + pass - def rpc_send(self, recipient_id, msg): - pass + def broadcast_send(self, body, sender=None, subject=None, correlation_id=None): + pass - def broadcast_send(self, body, sender=None, subject=None, correlation_id=None): - pass + comm = LoopCommunicator(_Communicator()) + coordinator = RmqCoordinator(comm) + yield coordinator -@pytest.fixture -def loop_communicator(): - """Return an instance of `LoopCommunicator`.""" - return LoopCommunicator(Communicator()) + coordinator.close() @pytest.fixture def subscriber(): - """Return an instance of `Subscriber`.""" + """Return an instance of mocked `Subscriber`.""" + + class Subscriber: + """Test class that mocks a subscriber.""" + + def __call__(self): + pass + return Subscriber() -def test_add_rpc_subscriber(loop_communicator, subscriber): +def test_add_rpc_subscriber(_coordinator, subscriber): """Test the `LoopCommunicator.add_rpc_subscriber` method.""" - assert loop_communicator.add_rpc_subscriber(subscriber) is not None + assert _coordinator.add_rpc_subscriber(subscriber) is not None identifier = 'identifier' - assert loop_communicator.add_rpc_subscriber(subscriber, identifier) == identifier + assert _coordinator.add_rpc_subscriber(subscriber, identifier) == identifier -def test_remove_rpc_subscriber(loop_communicator, subscriber): +def test_remove_rpc_subscriber(_coordinator, subscriber): """Test the `LoopCommunicator.remove_rpc_subscriber` method.""" - identifier = loop_communicator.add_rpc_subscriber(subscriber) - loop_communicator.remove_rpc_subscriber(identifier) + identifier = _coordinator.add_rpc_subscriber(subscriber) + _coordinator.remove_rpc_subscriber(identifier) -def test_add_broadcast_subscriber(loop_communicator, subscriber): +def test_add_broadcast_subscriber(_coordinator, subscriber): """Test the `LoopCommunicator.add_broadcast_subscriber` method.""" - assert loop_communicator.add_broadcast_subscriber(subscriber) is not None + assert _coordinator.add_broadcast_subscriber(subscriber) is not None identifier = 'identifier' - assert loop_communicator.add_broadcast_subscriber(subscriber, identifier=identifier) == identifier + assert _coordinator.add_broadcast_subscriber(subscriber, identifier=identifier) == identifier -def test_remove_broadcast_subscriber(loop_communicator, subscriber): +def test_remove_broadcast_subscriber(_coordinator, subscriber): """Test the `LoopCommunicator.remove_broadcast_subscriber` method.""" - identifier = loop_communicator.add_broadcast_subscriber(subscriber) - loop_communicator.remove_broadcast_subscriber(identifier) + identifier = _coordinator.add_broadcast_subscriber(subscriber) + _coordinator.remove_broadcast_subscriber(identifier) -def test_add_task_subscriber(loop_communicator, subscriber): +def test_add_task_subscriber(_coordinator, subscriber): """Test the `LoopCommunicator.add_task_subscriber` method.""" - assert loop_communicator.add_task_subscriber(subscriber) is not None + assert _coordinator.add_task_subscriber(subscriber) is not None -def test_remove_task_subscriber(loop_communicator, subscriber): +def test_remove_task_subscriber(_coordinator, subscriber): """Test the `LoopCommunicator.remove_task_subscriber` method.""" - identifier = loop_communicator.add_task_subscriber(subscriber) - loop_communicator.remove_task_subscriber(identifier) + identifier = _coordinator.add_task_subscriber(subscriber) + _coordinator.remove_task_subscriber(identifier) diff --git a/tests/rmq/test_communicator.py b/tests/rmq/test_communicator.py index e9d20db5..2d7b4787 100644 --- a/tests/rmq/test_communicator.py +++ b/tests/rmq/test_communicator.py @@ -6,16 +6,17 @@ import shutil import tempfile import uuid - -from kiwipy.rmq.communicator import kiwipy import pytest import shortuuid import yaml -from kiwipy import BroadcastFilter, rmq + +from kiwipy.rmq import RmqThreadCommunicator import plumpy -from plumpy.rmq import communications, process_comms +from plumpy.coordinator import Coordinator +from plumpy.rmq import communications, process_control +from . import RmqCoordinator from .. import utils @@ -30,14 +31,14 @@ def persister(): @pytest.fixture -def loop_communicator(): +def _coordinator(): message_exchange = f'{__file__}.{shortuuid.uuid()}' task_exchange = f'{__file__}.{shortuuid.uuid()}' task_queue = f'{__file__}.{shortuuid.uuid()}' encoder = functools.partial(yaml.dump, encoding='utf-8') decoder = functools.partial(yaml.load, Loader=yaml.FullLoader) - thread_communicator = rmq.RmqThreadCommunicator.connect( + thread_comm = RmqThreadCommunicator.connect( connection_params={'url': 'amqp://guest:guest@localhost:5672/'}, message_exchange=message_exchange, task_exchange=task_exchange, @@ -48,24 +49,24 @@ def loop_communicator(): loop = asyncio.get_event_loop() loop.set_debug(True) + comm = communications.LoopCommunicator(thread_comm, loop=loop) + coordinator = RmqCoordinator(comm) - communicator = communications.LoopCommunicator(thread_communicator, loop=loop) - - yield communicator + yield coordinator - thread_communicator.close() + coordinator.close() @pytest.fixture -def async_controller(loop_communicator: communications.LoopCommunicator): - yield process_comms.RemoteProcessController(loop_communicator) +def async_controller(_coordinator): + yield process_control.RemoteProcessController(_coordinator) class TestLoopCommunicator: """Make sure the loop communicator is working as expected""" @pytest.mark.asyncio - async def test_broadcast(self, loop_communicator): + async def test_broadcast(self, _coordinator): BROADCAST = {'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420} # noqa: N806 broadcast_future = asyncio.Future() @@ -78,14 +79,14 @@ def get_broadcast(_comm, body, sender, subject, correlation_id): {'body': body, 'sender': sender, 'subject': subject, 'correlation_id': correlation_id} ) - loop_communicator.add_broadcast_subscriber(get_broadcast) - loop_communicator.broadcast_send(**BROADCAST) + _coordinator.add_broadcast_subscriber(get_broadcast) + _coordinator.broadcast_send(**BROADCAST) result = await broadcast_future assert result == BROADCAST @pytest.mark.asyncio - async def test_broadcast_filter(self, loop_communicator: kiwipy.Communicator): + async def test_broadcast_filter(self, _coordinator: Coordinator): broadcast_future = asyncio.Future() def ignore_broadcast(_comm, body, sender, subject, correlation_id): @@ -94,17 +95,15 @@ def ignore_broadcast(_comm, body, sender, subject, correlation_id): def get_broadcast(_comm, body, sender, subject, correlation_id): broadcast_future.set_result(True) - loop_communicator.add_broadcast_subscriber(ignore_broadcast, subject_filter='other') - loop_communicator.add_broadcast_subscriber(get_broadcast) - loop_communicator.broadcast_send( - **{'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420} - ) + _coordinator.add_broadcast_subscriber(ignore_broadcast, subject_filter='other') + _coordinator.add_broadcast_subscriber(get_broadcast) + _coordinator.broadcast_send(**{'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420}) result = await broadcast_future assert result is True @pytest.mark.asyncio - async def test_rpc(self, loop_communicator): + async def test_rpc(self, _coordinator): MSG = 'rpc this' # noqa: N806 rpc_future = asyncio.Future() @@ -114,14 +113,14 @@ def get_rpc(_comm, msg): assert loop is asyncio.get_event_loop() rpc_future.set_result(msg) - loop_communicator.add_rpc_subscriber(get_rpc, 'rpc') - loop_communicator.rpc_send('rpc', MSG) + _coordinator.add_rpc_subscriber(get_rpc, 'rpc') + _coordinator.rpc_send('rpc', MSG) result = await rpc_future assert result == MSG @pytest.mark.asyncio - async def test_task(self, loop_communicator): + async def test_task(self, _coordinator): TASK = 'task this' # noqa: N806 task_future = asyncio.Future() @@ -131,8 +130,8 @@ def get_task(_comm, msg): assert loop is asyncio.get_event_loop() task_future.set_result(msg) - loop_communicator.add_task_subscriber(get_task) - loop_communicator.task_send(TASK) + _coordinator.add_task_subscriber(get_task) + _coordinator.task_send(TASK) result = await task_future assert result == TASK @@ -140,43 +139,43 @@ def get_task(_comm, msg): class TestTaskActions: @pytest.mark.asyncio - async def test_launch(self, loop_communicator, async_controller, persister): + async def test_launch(self, _coordinator, async_controller, persister): # Let the process run to the end loop = asyncio.get_event_loop() - loop_communicator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) + _coordinator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) result = await async_controller.launch_process(utils.DummyProcess) # Check that we got a result assert result == utils.DummyProcess.EXPECTED_OUTPUTS @pytest.mark.asyncio - async def test_launch_nowait(self, loop_communicator, async_controller, persister): + async def test_launch_nowait(self, _coordinator, async_controller, persister): """Testing launching but don't wait, just get the pid""" loop = asyncio.get_event_loop() - loop_communicator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) + _coordinator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) pid = await async_controller.launch_process(utils.DummyProcess, nowait=True) assert isinstance(pid, uuid.UUID) @pytest.mark.asyncio - async def test_execute_action(self, loop_communicator, async_controller, persister): + async def test_execute_action(self, _coordinator, async_controller, persister): """Test the process execute action""" loop = asyncio.get_event_loop() - loop_communicator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) + _coordinator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) result = await async_controller.execute_process(utils.DummyProcessWithOutput) assert utils.DummyProcessWithOutput.EXPECTED_OUTPUTS == result @pytest.mark.asyncio - async def test_execute_action_nowait(self, loop_communicator, async_controller, persister): + async def test_execute_action_nowait(self, _coordinator, async_controller, persister): """Test the process execute action""" loop = asyncio.get_event_loop() - loop_communicator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) + _coordinator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) pid = await async_controller.execute_process(utils.DummyProcessWithOutput, nowait=True) assert isinstance(pid, uuid.UUID) @pytest.mark.asyncio - async def test_launch_many(self, loop_communicator, async_controller, persister): + async def test_launch_many(self, _coordinator, async_controller, persister): """Test launching multiple processes""" loop = asyncio.get_event_loop() - loop_communicator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) + _coordinator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) num_to_launch = 10 launch_futures = [] @@ -189,10 +188,10 @@ async def test_launch_many(self, loop_communicator, async_controller, persister) assert isinstance(result, uuid.UUID) @pytest.mark.asyncio - async def test_continue(self, loop_communicator, async_controller, persister): + async def test_continue(self, _coordinator, async_controller, persister): """Test continuing a saved process""" loop = asyncio.get_event_loop() - loop_communicator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) + _coordinator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) process = utils.DummyProcessWithOutput() persister.save_checkpoint(process) pid = process.pid diff --git a/tests/rmq/test_process_comms.py b/tests/rmq/test_process_control.py similarity index 71% rename from tests/rmq/test_process_comms.py rename to tests/rmq/test_process_control.py index 4d9bca29..79a98ba3 100644 --- a/tests/rmq/test_process_comms.py +++ b/tests/rmq/test_process_control.py @@ -7,45 +7,46 @@ from kiwipy import rmq import plumpy -from plumpy.message import KILL_MSG, MESSAGE_KEY -from plumpy.rmq import process_comms +from plumpy.rmq import process_control +from . import RmqCoordinator from .. import utils @pytest.fixture -def thread_communicator(): +def _coordinator(): message_exchange = f'{__file__}.{shortuuid.uuid()}' task_exchange = f'{__file__}.{shortuuid.uuid()}' task_queue = f'{__file__}.{shortuuid.uuid()}' - communicator = rmq.RmqThreadCommunicator.connect( + comm = rmq.RmqThreadCommunicator.connect( connection_params={'url': 'amqp://guest:guest@localhost:5672/'}, message_exchange=message_exchange, task_exchange=task_exchange, task_queue=task_queue, ) - communicator._loop.set_debug(True) + comm._loop.set_debug(True) + coordinator = RmqCoordinator(comm) - yield communicator + yield coordinator - communicator.close() + coordinator.close() @pytest.fixture -def async_controller(thread_communicator: rmq.RmqThreadCommunicator): - yield process_comms.RemoteProcessController(thread_communicator) +def async_controller(_coordinator): + yield process_control.RemoteProcessController(_coordinator) @pytest.fixture -def sync_controller(thread_communicator: rmq.RmqThreadCommunicator): - yield process_comms.RemoteProcessThreadController(thread_communicator) +def sync_controller(_coordinator): + yield process_control.RemoteProcessThreadController(_coordinator) class TestRemoteProcessController: @pytest.mark.asyncio - async def test_pause(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_pause(self, _coordinator, async_controller): + proc = utils.WaitForSignalProcess(coordinator=_coordinator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) # Send a pause message @@ -56,8 +57,8 @@ async def test_pause(self, thread_communicator, async_controller): assert proc.paused @pytest.mark.asyncio - async def test_play(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_play(self, _coordinator, async_controller): + proc = utils.WaitForSignalProcess(coordinator=_coordinator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) assert proc.pause() @@ -74,8 +75,8 @@ async def test_play(self, thread_communicator, async_controller): await async_controller.kill_process(proc.pid) @pytest.mark.asyncio - async def test_kill(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_kill(self, _coordinator, async_controller): + proc = utils.WaitForSignalProcess(coordinator=_coordinator) # Run the process in the event loop asyncio.ensure_future(proc.step_until_terminated()) @@ -87,8 +88,8 @@ async def test_kill(self, thread_communicator, async_controller): assert proc.state == plumpy.ProcessState.KILLED @pytest.mark.asyncio - async def test_status(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_status(self, _coordinator, async_controller): + proc = utils.WaitForSignalProcess(coordinator=_coordinator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) @@ -100,15 +101,15 @@ async def test_status(self, thread_communicator, async_controller): # make sure proc reach the final state await async_controller.kill_process(proc.pid) - def test_broadcast(self, thread_communicator): + def test_broadcast(self, _coordinator): messages = [] def on_broadcast_receive(**msg): messages.append(msg) - thread_communicator.add_broadcast_subscriber(on_broadcast_receive) + _coordinator.add_broadcast_subscriber(on_broadcast_receive) - proc = utils.DummyProcess(coordinator=thread_communicator) + proc = utils.DummyProcess(coordinator=_coordinator) proc.execute() expected_subjects = [] @@ -122,8 +123,8 @@ def on_broadcast_receive(**msg): class TestRemoteProcessThreadController: @pytest.mark.asyncio - async def test_pause(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_pause(self, _coordinator, sync_controller): + proc = utils.WaitForSignalProcess(coordinator=_coordinator) # Send a pause message pause_future = sync_controller.pause_process(proc.pid) @@ -136,22 +137,22 @@ async def test_pause(self, thread_communicator, sync_controller): assert proc.paused @pytest.mark.asyncio - async def test_pause_all(self, thread_communicator, sync_controller): + async def test_pause_all(self, _coordinator, sync_controller): """Test pausing all processes on a communicator""" procs = [] for _ in range(10): - procs.append(utils.WaitForSignalProcess(coordinator=thread_communicator)) + procs.append(utils.WaitForSignalProcess(coordinator=_coordinator)) sync_controller.pause_all("Slow yo' roll") # Wait until they are all paused await utils.wait_util(lambda: all([proc.paused for proc in procs])) @pytest.mark.asyncio - async def test_play_all(self, thread_communicator, sync_controller): + async def test_play_all(self, _coordinator, sync_controller): """Test pausing all processes on a communicator""" procs = [] for _ in range(10): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=_coordinator) procs.append(proc) proc.pause('hold tight') @@ -161,8 +162,8 @@ async def test_play_all(self, thread_communicator, sync_controller): await utils.wait_util(lambda: all([not proc.paused for proc in procs])) @pytest.mark.asyncio - async def test_play(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_play(self, _coordinator, sync_controller): + proc = utils.WaitForSignalProcess(coordinator=_coordinator) assert proc.pause() # Send a play message @@ -175,8 +176,8 @@ async def test_play(self, thread_communicator, sync_controller): assert proc.state == plumpy.ProcessState.CREATED @pytest.mark.asyncio - async def test_kill(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_kill(self, _coordinator, sync_controller): + proc = utils.WaitForSignalProcess(coordinator=_coordinator) # Send a kill message kill_future = sync_controller.kill_process(proc.pid) @@ -189,19 +190,19 @@ async def test_kill(self, thread_communicator, sync_controller): assert proc.state == plumpy.ProcessState.KILLED @pytest.mark.asyncio - async def test_kill_all(self, thread_communicator, sync_controller): + async def test_kill_all(self, _coordinator, sync_controller): """Test pausing all processes on a communicator""" procs = [] for _ in range(10): - procs.append(utils.WaitForSignalProcess(coordinator=thread_communicator)) + procs.append(utils.WaitForSignalProcess(coordinator=_coordinator)) sync_controller.kill_all(msg_text='bang bang, I shot you down') await utils.wait_util(lambda: all([proc.killed() for proc in procs])) assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs]) @pytest.mark.asyncio - async def test_status(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_status(self, _coordinator, sync_controller): + proc = utils.WaitForSignalProcess(coordinator=_coordinator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) diff --git a/tests/test_processes.py b/tests/test_processes.py index 99e28de6..62a1e916 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -5,16 +5,14 @@ import enum import unittest -import kiwipy import pytest from plumpy.futures import CancellableAction -from tests import utils import plumpy from plumpy import BundleKeys, Process, ProcessState from plumpy.message import MessageBuilder from plumpy.utils import AttributesFrozendict -from tests import utils +from . import utils class ForgetToCallParent(plumpy.Process): @@ -1066,16 +1064,15 @@ def test_paused(self): self.assertSetEqual(events_tester.called, events_tester.expected_events) def test_broadcast(self): - # FIXME: here I need a mock test - communicator = kiwipy.LocalCommunicator() + coordinator = utils.MockCoordinator() messages = [] def on_broadcast_receive(_comm, body, sender, subject, correlation_id): messages.append({'body': body, 'subject': subject, 'sender': sender, 'correlation_id': correlation_id}) - communicator.add_broadcast_subscriber(on_broadcast_receive) - proc = utils.DummyProcess(coordinator=communicator) + coordinator.add_broadcast_subscriber(on_broadcast_receive) + proc = utils.DummyProcess(coordinator=coordinator) proc.execute() expected_subjects = [] @@ -1083,8 +1080,7 @@ def on_broadcast_receive(_comm, body, sender, subject, correlation_id): from_state = utils.DummyProcess.EXPECTED_STATE_SEQUENCE[i - 1].value if i != 0 else None expected_subjects.append(f'state_changed.{from_state}.{state.value}') - for i, message in enumerate(messages): - self.assertEqual(message['subject'], expected_subjects[i]) + assert [msg['subject'] for msg in messages] == expected_subjects class _RestartProcess(utils.WaitForSignalProcess): diff --git a/tests/utils.py b/tests/utils.py index 123d6e72..25936415 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,16 +3,144 @@ import asyncio import collections +import sys +from typing import Any import unittest from collections.abc import Mapping +import concurrent.futures import plumpy from plumpy import persistence, process_states, processes, utils +from plumpy.exceptions import CoordinatorConnectionError from plumpy.message import MessageBuilder +from plumpy.rmq import TaskRejected +import shortuuid Snapshot = collections.namedtuple('Snapshot', ['state', 'bundle', 'outputs']) +class MockCoordinator: + def __init__(self): + self._task_subscribers = {} + self._broadcast_subscribers = {} + self._rpc_subscribers = {} + self._closed = False + + def is_closed(self) -> bool: + return self._closed + + def close(self): + if self._closed: + return + self._closed = True + del self._task_subscribers + del self._broadcast_subscribers + del self._rpc_subscribers + + def add_rpc_subscriber(self, subscriber, identifier=None) -> Any: + self._ensure_open() + identifier = identifier or shortuuid.uuid() + if identifier in self._rpc_subscribers: + raise RuntimeError(f"Duplicate RPC subscriber with identifier '{identifier}'") + self._rpc_subscribers[identifier] = subscriber + return identifier + + def remove_rpc_subscriber(self, identifier): + self._ensure_open() + try: + self._rpc_subscribers.pop(identifier) + except KeyError as exc: + raise ValueError(f"Unknown subscriber '{identifier}'") from exc + + def add_task_subscriber(self, subscriber, identifier=None): + """ + Register a task subscriber + + :param subscriber: The task callback function + :param identifier: the subscriber identifier + """ + self._ensure_open() + identifier = identifier or shortuuid.uuid() + if identifier in self._rpc_subscribers: + raise RuntimeError(f"Duplicate RPC subscriber with identifier '{identifier}'") + self._task_subscribers[identifier] = subscriber + return identifier + + def remove_task_subscriber(self, identifier): + """ + Remove a task subscriber + + :param identifier: the subscriber to remove + :raises: ValueError if identifier does not correspond to a known subscriber + """ + self._ensure_open() + try: + self._task_subscribers.pop(identifier) + except KeyError as exception: + raise ValueError(f"Unknown subscriber: '{identifier}'") from exception + + def add_broadcast_subscriber(self, subscriber, subject_filter=None, identifier=None) -> Any: + self._ensure_open() + identifier = identifier or shortuuid.uuid() + if identifier in self._broadcast_subscribers: + raise RuntimeError(f"Duplicate RPC subscriber with identifier '{identifier}'") + + self._broadcast_subscribers[identifier] = subscriber + return identifier + + def remove_broadcast_subscriber(self, identifier): + self._ensure_open() + try: + del self._broadcast_subscribers[identifier] + except KeyError as exception: + raise ValueError(f"Broadcast subscriber '{identifier}' unknown") from exception + + def task_send(self, msg, no_reply=False): + self._ensure_open() + future = concurrent.futures.Future() + + for subscriber in self._task_subscribers.values(): + try: + result = subscriber(self, msg) + future.set_result(result) + break + except TaskRejected: + pass + except Exception: + future.set_exception(RuntimeError(sys.exc_info())) + break + + if no_reply: + return None + + return future + + def rpc_send(self, recipient_id, msg): + self._ensure_open() + try: + subscriber = self._rpc_subscribers[recipient_id] + except KeyError as exception: + raise RuntimeError(f"Unknown rpc recipient '{recipient_id}'") from exception + else: + future = concurrent.futures.Future() + try: + future.set_result(subscriber(self, msg)) + except Exception: + future.set_exception(RuntimeError(sys.exc_info())) + + return future + + def broadcast_send(self, body, sender=None, subject=None, correlation_id=None): + self._ensure_open() + for subscriber in self._broadcast_subscribers.values(): + subscriber(self, body=body, sender=sender, subject=subject, correlation_id=correlation_id) + return True + + def _ensure_open(self): + if self.is_closed(): + raise CoordinatorConnectionError + + class TestCase(unittest.TestCase): pass From 15b267b57599771ff5a7c00617433cf8128234dd Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 18 Dec 2024 11:56:22 +0100 Subject: [PATCH 12/22] broadcast subscriber has versatile filters --- src/plumpy/coordinator.py | 3 +- src/plumpy/message.py | 15 +++---- src/plumpy/processes.py | 2 +- src/plumpy/rmq/coordinator.py | 76 ++++++++++++++++++++++++++++++++++ src/plumpy/rmq/futures.py | 31 +++++++++++++- tests/rmq/test_communicator.py | 2 +- tests/utils.py | 2 +- 7 files changed, 117 insertions(+), 14 deletions(-) create mode 100644 src/plumpy/rmq/coordinator.py diff --git a/src/plumpy/coordinator.py b/src/plumpy/coordinator.py index b3dcbec5..29533bf4 100644 --- a/src/plumpy/coordinator.py +++ b/src/plumpy/coordinator.py @@ -21,7 +21,8 @@ def add_rpc_subscriber(self, subscriber: 'RpcSubscriber', identifier: 'ID_TYPE | def add_broadcast_subscriber( self, subscriber: 'BroadcastSubscriber', - subject_filter: str | Pattern[str] | None = None, + subject_filters: list[Hashable | Pattern[str]] | None = None, + sender_filters: list[Hashable | Pattern[str]] | None = None, identifier: 'ID_TYPE | None' = None, ) -> Any: ... diff --git a/src/plumpy/message.py b/src/plumpy/message.py index 58c1c6bd..813402b8 100644 --- a/src/plumpy/message.py +++ b/src/plumpy/message.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -"""Module for process level communication functions and classes""" +"""Module for process level coordination functions and classes""" from __future__ import annotations @@ -69,11 +69,11 @@ class Intent: class MessageBuilder: - """MessageBuilder will construct different messages that can passing over communicator.""" + """MessageBuilder will construct different messages that can passing over coordinator.""" @classmethod def play(cls, text: str | None = None) -> MessageType: - """The play message send over communicator.""" + """The play message send over coordinator.""" return { INTENT_KEY: Intent.PLAY, MESSAGE_KEY: text, @@ -81,7 +81,7 @@ def play(cls, text: str | None = None) -> MessageType: @classmethod def pause(cls, text: str | None = None) -> MessageType: - """The pause message send over communicator.""" + """The pause message send over coordinator.""" return { INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: text, @@ -89,7 +89,7 @@ def pause(cls, text: str | None = None) -> MessageType: @classmethod def kill(cls, text: str | None = None, force_kill: bool = False) -> MessageType: - """The kill message send over communicator.""" + """The kill message send over coordinator.""" return { INTENT_KEY: Intent.KILL, MESSAGE_KEY: text, @@ -98,7 +98,7 @@ def kill(cls, text: str | None = None, force_kill: bool = False) -> MessageType: @classmethod def status(cls, text: str | None = None) -> MessageType: - """The status message send over communicator.""" + """The status message send over coordinator.""" return { INTENT_KEY: Intent.STATUS, MESSAGE_KEY: text, @@ -254,7 +254,6 @@ async def _launch( """ Launch the process - :param _communicator: the communicator :param process_class: the process class to launch :param persist: should the process be persisted :param nowait: if True only return when the process finishes @@ -288,7 +287,6 @@ async def _continue(self, pid: 'PID_TYPE', nowait: bool, tag: Optional[str] = No """ Continue the process - :param _communicator: the communicator :param pid: the pid of the process to continue :param nowait: if True don't wait for the process to complete :param tag: the checkpoint tag to continue from @@ -320,7 +318,6 @@ async def _create( """ Create the process - :param _communicator: the communicator :param process_class: the process class to create :param persist: should the process be persisted :param init_args: positional arguments to the process constructor diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index c1381471..a7963bc8 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -325,7 +325,7 @@ def init(self) -> None: try: # filter out state change broadcasts identifier = self._coordinator.add_broadcast_subscriber( - self.broadcast_receive, subject_filter=re.compile(r'^(?!state_changed).*'), identifier=str(self.pid) + self.broadcast_receive, subject_filters=[re.compile(r'^(?!state_changed).*')], identifier=str(self.pid) ) self.add_cleanup(functools.partial(self._coordinator.remove_broadcast_subscriber, identifier)) except concurrent.futures.TimeoutError: diff --git a/src/plumpy/rmq/coordinator.py b/src/plumpy/rmq/coordinator.py new file mode 100644 index 00000000..9397d307 --- /dev/null +++ b/src/plumpy/rmq/coordinator.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +import kiwipy +import concurrent.futures + +from plumpy.exceptions import CoordinatorConnectionError + +__all__ = ['RmqCoordinator'] + +class RmqCoordinator: + def __init__(self, comm: kiwipy.Communicator): + self._comm = comm + + # XXX: naming - `add_receiver_rpc` + def add_rpc_subscriber(self, subscriber, identifier=None): + return self._comm.add_rpc_subscriber(subscriber, identifier) + + # XXX: naming - `add_receiver_broadcast` + def add_broadcast_subscriber( + self, + subscriber, + subject_filters=None, + sender_filters=None, + identifier=None, + ): + subscriber = kiwipy.BroadcastFilter(subscriber) + + subject_filters = subject_filters or [] + sender_filters = sender_filters or [] + + for filter in subject_filters: + subscriber.add_subject_filter(filter) + for filter in sender_filters: + subscriber.add_sender_filter(filter) + + return self._comm.add_broadcast_subscriber(subscriber, identifier) + + # XXX: naming - `add_reciver_task` (can be combined with two above maybe??) + def add_task_subscriber(self, subscriber, identifier=None): + return self._comm.add_task_subscriber(subscriber, identifier) + + def remove_rpc_subscriber(self, identifier): + return self._comm.remove_rpc_subscriber(identifier) + + def remove_broadcast_subscriber(self, identifier): + return self._comm.remove_broadcast_subscriber(identifier) + + def remove_task_subscriber(self, identifier): + return self._comm.remove_task_subscriber(identifier) + + # XXX: naming - `send_to` + def rpc_send(self, recipient_id, msg): + return self._comm.rpc_send(recipient_id, msg) + + # XXX: naming - `broadcast` + def broadcast_send( + self, + body, + sender=None, + subject=None, + correlation_id=None, + ): + from aio_pika.exceptions import ChannelInvalidStateError, AMQPConnectionError + + try: + rsp = self._comm.broadcast_send(body, sender, subject, correlation_id) + except (ChannelInvalidStateError, AMQPConnectionError, concurrent.futures.TimeoutError) as exc: + raise CoordinatorConnectionError from exc + else: + return rsp + + # XXX: naming - `assign_task` (this may able to be combined with send_to) + def task_send(self, task, no_reply=False): + return self._comm.task_send(task, no_reply) + + def close(self): + self._comm.close() diff --git a/src/plumpy/rmq/futures.py b/src/plumpy/rmq/futures.py index 73e9e36f..0ebe0d45 100644 --- a/src/plumpy/rmq/futures.py +++ b/src/plumpy/rmq/futures.py @@ -10,7 +10,7 @@ import kiwipy -__all__ = ['wrap_to_concurrent_future'] +__all__ = ['wrap_to_concurrent_future', 'unwrap_kiwi_future'] def _convert_future_exc(exc): @@ -111,3 +111,32 @@ def wrap_to_concurrent_future(future: asyncio.Future[Any]) -> kiwipy.Future: new_future = kiwipy.Future() _chain_future(future, new_future) return new_future + +# XXX: this required in aiida-core, see if really need this unwrap. +def unwrap_kiwi_future(future: kiwipy.Future) -> kiwipy.Future: + """ + Create a kiwi future that represents the final results of a nested series of futures, + meaning that if the futures provided itself resolves to a future the returned + future will not resolve to a value until the final chain of futures is not a future + but a concrete value. If at any point in the chain a future resolves to an exception + then the returned future will also resolve to that exception. + + :param future: the future to unwrap + :return: the unwrapping future + + """ + unwrapping = kiwipy.Future() + + def unwrap(fut: kiwipy.Future) -> None: + if fut.cancelled(): + unwrapping.cancel() + else: + with kiwipy.capture_exceptions(unwrapping): + result = fut.result() + if isinstance(result, kiwipy.Future): + result.add_done_callback(unwrap) + else: + unwrapping.set_result(result) + + future.add_done_callback(unwrap) + return unwrapping diff --git a/tests/rmq/test_communicator.py b/tests/rmq/test_communicator.py index 2d7b4787..e0baaea5 100644 --- a/tests/rmq/test_communicator.py +++ b/tests/rmq/test_communicator.py @@ -95,7 +95,7 @@ def ignore_broadcast(_comm, body, sender, subject, correlation_id): def get_broadcast(_comm, body, sender, subject, correlation_id): broadcast_future.set_result(True) - _coordinator.add_broadcast_subscriber(ignore_broadcast, subject_filter='other') + _coordinator.add_broadcast_subscriber(ignore_broadcast, subject_filters=['other']) _coordinator.add_broadcast_subscriber(get_broadcast) _coordinator.broadcast_send(**{'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420}) diff --git a/tests/utils.py b/tests/utils.py index 25936415..3d4458f4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -79,7 +79,7 @@ def remove_task_subscriber(self, identifier): except KeyError as exception: raise ValueError(f"Unknown subscriber: '{identifier}'") from exception - def add_broadcast_subscriber(self, subscriber, subject_filter=None, identifier=None) -> Any: + def add_broadcast_subscriber(self, subscriber, subject_filters=None, sender_filters=None, identifier=None) -> Any: self._ensure_open() identifier = identifier or shortuuid.uuid() if identifier in self._broadcast_subscribers: From 23d954c2eb38ab2358ef711ba484ee2239cda1de Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Thu, 19 Dec 2024 01:06:21 +0100 Subject: [PATCH 13/22] Generic typing for Coordinator --- src/plumpy/rmq/communications.py | 18 ++++++++++++------ src/plumpy/rmq/coordinator.py | 13 +++++++++++-- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/plumpy/rmq/communications.py b/src/plumpy/rmq/communications.py index cb0012c9..61f89fc1 100644 --- a/src/plumpy/rmq/communications.py +++ b/src/plumpy/rmq/communications.py @@ -5,7 +5,7 @@ import asyncio import functools -from typing import TYPE_CHECKING, Any, Callable, Hashable, Optional +from typing import TYPE_CHECKING, Any, Callable, Generic, Hashable, Optional, TypeVar, final import kiwipy @@ -77,10 +77,11 @@ def converted(communicator: kiwipy.Communicator, *args: Any, **kwargs: Any) -> k return converted +T = TypeVar('T', bound=kiwipy.Communicator) def wrap_communicator( - communicator: kiwipy.Communicator, loop: Optional[asyncio.AbstractEventLoop] = None -) -> 'LoopCommunicator': + communicator: T, loop: Optional[asyncio.AbstractEventLoop] = None +) -> 'LoopCommunicator[T]': """ Wrap a communicator such that all callbacks made to any subscribers are scheduled on the given event loop. @@ -100,10 +101,11 @@ def wrap_communicator( return LoopCommunicator(communicator, loop) -class LoopCommunicator(kiwipy.Communicator): # type: ignore +@final +class LoopCommunicator(Generic[T], kiwipy.Communicator): # type: ignore """Wrapper around a `kiwipy.Communicator` that schedules any subscriber messages on a given event loop.""" - def __init__(self, communicator: kiwipy.Communicator, loop: Optional[asyncio.AbstractEventLoop] = None): + def __init__(self, communicator: T, loop: Optional[asyncio.AbstractEventLoop] = None): """ :param communicator: The kiwipy communicator :param loop: The event loop to schedule callbacks on @@ -114,6 +116,10 @@ def __init__(self, communicator: kiwipy.Communicator, loop: Optional[asyncio.Abs self._communicator = communicator self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() + @property + def inner(self) -> T: + return self._communicator + def loop(self) -> asyncio.AbstractEventLoop: return self._loop @@ -152,7 +158,7 @@ def broadcast_send( sender: Optional[str] = None, subject: Optional[str] = None, correlation_id: Optional['ID_TYPE'] = None, - ) -> futures.Future: + ) -> kiwipy.Future: return self._communicator.broadcast_send(body, sender, subject, correlation_id) def is_closed(self) -> bool: diff --git a/src/plumpy/rmq/coordinator.py b/src/plumpy/rmq/coordinator.py index 9397d307..c529b61c 100644 --- a/src/plumpy/rmq/coordinator.py +++ b/src/plumpy/rmq/coordinator.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from typing import Generic, TypeVar, final import kiwipy import concurrent.futures @@ -6,10 +7,18 @@ __all__ = ['RmqCoordinator'] -class RmqCoordinator: - def __init__(self, comm: kiwipy.Communicator): +U = TypeVar("U", bound=kiwipy.Communicator) + +@final +class RmqCoordinator(Generic[U]): + def __init__(self, comm: U): self._comm = comm + @property + def communicator(self) -> U: + """The inner communicator.""" + return self._comm + # XXX: naming - `add_receiver_rpc` def add_rpc_subscriber(self, subscriber, identifier=None): return self._comm.add_rpc_subscriber(subscriber, identifier) From 46bd2d335fdd191bd5f043920385f891252a1adb Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Thu, 19 Dec 2024 16:16:46 +0100 Subject: [PATCH 14/22] Adopt new message protocol and changes required for aiida-core support --- src/plumpy/coordinator.py | 2 ++ src/plumpy/message.py | 20 ++++++-------------- src/plumpy/processes.py | 23 +++++++++++++++-------- src/plumpy/rmq/process_control.py | 9 ++++----- tests/test_processes.py | 2 +- 5 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/plumpy/coordinator.py b/src/plumpy/coordinator.py index 29533bf4..dc501c4e 100644 --- a/src/plumpy/coordinator.py +++ b/src/plumpy/coordinator.py @@ -45,3 +45,5 @@ def broadcast_send( ) -> Any: ... def task_send(self, task: Any, no_reply: bool = False) -> Any: ... + + def close(self) -> None: ... diff --git a/src/plumpy/message.py b/src/plumpy/message.py index 813402b8..14a4e251 100644 --- a/src/plumpy/message.py +++ b/src/plumpy/message.py @@ -16,10 +16,7 @@ from .utils import PID_TYPE __all__ = [ - 'KILL_MSG', - 'PAUSE_MSG', - 'PLAY_MSG', - 'STATUS_MSG', + 'MessageBuilder', 'ProcessLauncher', 'create_continue_body', 'create_launch_body', @@ -29,7 +26,7 @@ from .processes import Process INTENT_KEY = 'intent' -MESSAGE_KEY = 'message' +MESSAGE_TEXT_KEY = 'message' FORCE_KILL_KEY = 'force_kill' @@ -42,11 +39,6 @@ class Intent: STATUS: str = 'status' -PAUSE_MSG = {INTENT_KEY: Intent.PAUSE} -PLAY_MSG = {INTENT_KEY: Intent.PLAY} -KILL_MSG = {INTENT_KEY: Intent.KILL} -STATUS_MSG = {INTENT_KEY: Intent.STATUS} - TASK_KEY = 'task' TASK_ARGS = 'args' PERSIST_KEY = 'persist' @@ -76,7 +68,7 @@ def play(cls, text: str | None = None) -> MessageType: """The play message send over coordinator.""" return { INTENT_KEY: Intent.PLAY, - MESSAGE_KEY: text, + MESSAGE_TEXT_KEY: text, } @classmethod @@ -84,7 +76,7 @@ def pause(cls, text: str | None = None) -> MessageType: """The pause message send over coordinator.""" return { INTENT_KEY: Intent.PAUSE, - MESSAGE_KEY: text, + MESSAGE_TEXT_KEY: text, } @classmethod @@ -92,7 +84,7 @@ def kill(cls, text: str | None = None, force_kill: bool = False) -> MessageType: """The kill message send over coordinator.""" return { INTENT_KEY: Intent.KILL, - MESSAGE_KEY: text, + MESSAGE_TEXT_KEY: text, FORCE_KILL_KEY: force_kill, } @@ -101,7 +93,7 @@ def status(cls, text: str | None = None) -> MessageType: """The status message send over coordinator.""" return { INTENT_KEY: Intent.STATUS, - MESSAGE_KEY: text, + MESSAGE_TEXT_KEY: text, } diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index a7963bc8..c40347b2 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -32,6 +32,8 @@ cast, ) +import kiwipy + from plumpy.coordinator import Coordinator try: @@ -321,12 +323,15 @@ def init(self) -> None: self.add_cleanup(functools.partial(self._coordinator.remove_rpc_subscriber, identifier)) except concurrent.futures.TimeoutError: self.logger.exception('Process<%s>: failed to register as an RPC subscriber', self.pid) + # XXX: handle duplicate subscribing here: see aiida-core test_duplicate_subscriber_identifier. try: # filter out state change broadcasts - identifier = self._coordinator.add_broadcast_subscriber( - self.broadcast_receive, subject_filters=[re.compile(r'^(?!state_changed).*')], identifier=str(self.pid) - ) + subscriber = kiwipy.BroadcastFilter(self.broadcast_receive, subject=re.compile(r'^(?!state_changed).*')) + identifier = self._coordinator.add_broadcast_subscriber(subscriber, identifier=str(self.pid)) + # identifier = self._coordinator.add_broadcast_subscriber( + # subscriber, subject_filters=[re.compile(r'^(?!state_changed).*')], identifier=str(self.pid) + # ) self.add_cleanup(functools.partial(self._coordinator.remove_broadcast_subscriber, identifier)) except concurrent.futures.TimeoutError: self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid) @@ -787,6 +792,8 @@ def recursively_copy_dictionaries(value: Any) -> Any: self._uuid = uuid.uuid4() if self._pid is None: self._pid = self._uuid + # __import__('ipdb').set_trace() + # print("!!!!! ") @super_check def on_exit_running(self) -> None: @@ -955,9 +962,9 @@ def message_receive(self, _comm: Coordinator, msg: MessageType) -> Any: if intent == message.Intent.PLAY: return self._schedule_rpc(self.play) if intent == message.Intent.PAUSE: - return self._schedule_rpc(self.pause, msg_text=msg.get(message.MESSAGE_KEY, None)) + return self._schedule_rpc(self.pause, msg_text=msg.get(MESSAGE_TEXT_KEY, None)) if intent == message.Intent.KILL: - return self._schedule_rpc(self.kill, msg_text=msg.get(message.MESSAGE_KEY, None)) + return self._schedule_rpc(self.kill, msg_text=msg.get(MESSAGE_TEXT_KEY, None)) if intent == message.Intent.STATUS: status_info: Dict[str, Any] = {} self.get_status_info(status_info) @@ -988,9 +995,9 @@ def broadcast_receive( if subject == message.Intent.PLAY: fn = self._schedule_rpc(self.play) elif subject == message.Intent.PAUSE: - return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None)) + fn = self._schedule_rpc(self.pause, msg_text=msg.get(MESSAGE_TEXT_KEY, None)) elif subject == message.Intent.KILL: - return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None)) + fn = self._schedule_rpc(self.kill, msg_text=msg.get(MESSAGE_TEXT_KEY, None)) if fn is None: self.logger.warning( @@ -1097,7 +1104,7 @@ def transition_failed( ) self.transition_to(new_state) - def pause(self, msg_text: Optional[str] = None) -> Union[bool, CancellableAction]: + def pause(self, msg_text: str | None = None) -> Union[bool, CancellableAction]: """Pause the process. :param msg: an optional message to set as the status. The current status will be saved in the private diff --git a/src/plumpy/rmq/process_control.py b/src/plumpy/rmq/process_control.py index e9ed3ef8..0caf1d7a 100644 --- a/src/plumpy/rmq/process_control.py +++ b/src/plumpy/rmq/process_control.py @@ -29,6 +29,7 @@ ProcessStatus = Any +# FIXME: the class not fit typing of ProcessController protocol class RemoteProcessController: """ Control remote processes using coroutines that will send messages and wait @@ -189,6 +190,7 @@ async def execute_process( return result +# FIXME: the class not fit typing of ProcessController protocol class RemoteProcessThreadController: """ A class that can be used to control and launch remote processes @@ -270,15 +272,12 @@ def kill_all(self, msg_text: Optional[str]) -> None: self._coordinator.broadcast_send(msg, subject=Intent.KILL) - def notify_all(self, msg: MessageType | None, sender: Hashable | None = None, subject: str | None = None) -> None: + def notify_msg(self, msg: MessageType, sender: Hashable | None = None, subject: str | None = None) -> None: """ - Notify all processes by broadcasting + Notify all processes by broadcasting of a msg :param msg: an optional pause message """ - if msg is None: - msg = MessageBuilder.kill() - self._coordinator.broadcast_send(msg, sender=sender, subject=subject) def continue_process( diff --git a/tests/test_processes.py b/tests/test_processes.py index 62a1e916..a05d09a3 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -10,7 +10,7 @@ import plumpy from plumpy import BundleKeys, Process, ProcessState -from plumpy.message import MessageBuilder +from plumpy.message import MESSAGE_TEXT_KEY, MessageBuilder from plumpy.utils import AttributesFrozendict from . import utils From 8e23a887c8a6ca8850d0ecb07eda849873fc6d91 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 30 Dec 2024 01:14:52 +0100 Subject: [PATCH 15/22] Simpler create_task_threadsafe implementation --- .python-version | 1 - src/plumpy/coordinator.py | 3 +++ src/plumpy/futures.py | 22 +++++++++++++--------- src/plumpy/rmq/communications.py | 5 +++-- tests/rmq/__init__.py | 4 ++-- 5 files changed, 21 insertions(+), 14 deletions(-) delete mode 100644 .python-version diff --git a/.python-version b/.python-version deleted file mode 100644 index 413c7e7e..00000000 --- a/.python-version +++ /dev/null @@ -1 +0,0 @@ -aiida-core-dev-3.12 diff --git a/src/plumpy/coordinator.py b/src/plumpy/coordinator.py index dc501c4e..702ea5f5 100644 --- a/src/plumpy/coordinator.py +++ b/src/plumpy/coordinator.py @@ -16,8 +16,10 @@ class Coordinator(Protocol): + # XXX: naming - 'add_message_handler' def add_rpc_subscriber(self, subscriber: 'RpcSubscriber', identifier: 'ID_TYPE | None' = None) -> Any: ... + # XXX: naming - 'add_broadcast_handler' def add_broadcast_subscriber( self, subscriber: 'BroadcastSubscriber', @@ -26,6 +28,7 @@ def add_broadcast_subscriber( identifier: 'ID_TYPE | None' = None, ) -> Any: ... + # XXX: naming - absorbed into 'add_message_handler' def add_task_subscriber(self, subscriber: 'TaskSubscriber', identifier: 'ID_TYPE | None' = None) -> 'ID_TYPE': ... def remove_rpc_subscriber(self, identifier: 'ID_TYPE | None') -> None: ... diff --git a/src/plumpy/futures.py b/src/plumpy/futures.py index f3e8a30b..b67c0e80 100644 --- a/src/plumpy/futures.py +++ b/src/plumpy/futures.py @@ -78,12 +78,16 @@ def create_task(coro: Callable[[], Awaitable[Any]], loop: Optional[asyncio.Abstr """ loop = loop or asyncio.get_event_loop() - future = loop.create_future() - - async def run_task() -> None: - with capture_exceptions(future): - res = await coro() - future.set_result(res) - - asyncio.run_coroutine_threadsafe(run_task(), loop) - return future + # future = loop.create_future() + # + # async def run_task() -> None: + # with capture_exceptions(future): + # res = await coro() + # future.set_result(res) + # + # asyncio.run_coroutine_threadsafe(run_task(), loop) + # return future + + return asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro(), loop) + ) diff --git a/src/plumpy/rmq/communications.py b/src/plumpy/rmq/communications.py index 61f89fc1..b27b65a1 100644 --- a/src/plumpy/rmq/communications.py +++ b/src/plumpy/rmq/communications.py @@ -9,7 +9,8 @@ import kiwipy -from plumpy import futures +from plumpy.futures import create_task +from plumpy.rmq.futures import wrap_to_concurrent_future from plumpy.utils import ensure_coroutine __all__ = [ @@ -72,7 +73,7 @@ def converted(communicator: kiwipy.Communicator, *args: Any, **kwargs: Any) -> k return kiwi_future msg_fn = functools.partial(coro, communicator, *args, **kwargs) - task_future = futures.create_task(msg_fn, loop) + task_future = create_task(msg_fn, loop) return wrap_to_concurrent_future(task_future) return converted diff --git a/tests/rmq/__init__.py b/tests/rmq/__init__.py index 72078829..2845b1bb 100644 --- a/tests/rmq/__init__.py +++ b/tests/rmq/__init__.py @@ -17,10 +17,10 @@ def add_rpc_subscriber(self, subscriber, identifier=None): def add_broadcast_subscriber( self, subscriber, - subject_filter=None, + subject_filters=None, identifier=None, ): - subscriber = kiwipy.BroadcastFilter(subscriber, subject=subject_filter) + subscriber = kiwipy.BroadcastFilter(subscriber, subject=subject_filters) return self._comm.add_broadcast_subscriber(subscriber, identifier) # XXX: naming - `add_reciver_task` (can be combined with two above maybe??) From f3f30953fe5edde92d5f82496f832a225564877b Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Fri, 10 Jan 2025 18:07:32 +0100 Subject: [PATCH 16/22] Remove RmqCoordinator to tests/util only --- src/plumpy/__init__.py | 2 +- src/plumpy/futures.py | 4 +- src/plumpy/message.py | 2 - src/plumpy/processes.py | 2 + src/plumpy/rmq/communications.py | 6 +-- src/plumpy/rmq/coordinator.py | 85 -------------------------------- src/plumpy/rmq/futures.py | 3 +- tests/rmq/__init__.py | 25 ++++++++-- 8 files changed, 31 insertions(+), 98 deletions(-) delete mode 100644 src/plumpy/rmq/coordinator.py diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index cc65ba23..864d2226 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -18,9 +18,9 @@ from .process_listener import * from .process_states import * from .processes import * +from .rmq import * from .utils import * from .workchains import * -from .rmq import * __all__ = ( events.__all__ diff --git a/src/plumpy/futures.py b/src/plumpy/futures.py index b67c0e80..ed43389e 100644 --- a/src/plumpy/futures.py +++ b/src/plumpy/futures.py @@ -88,6 +88,4 @@ def create_task(coro: Callable[[], Awaitable[Any]], loop: Optional[asyncio.Abstr # asyncio.run_coroutine_threadsafe(run_task(), loop) # return future - return asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro(), loop) - ) + return asyncio.wrap_future(asyncio.run_coroutine_threadsafe(coro(), loop)) diff --git a/src/plumpy/message.py b/src/plumpy/message.py index 14a4e251..009f1b26 100644 --- a/src/plumpy/message.py +++ b/src/plumpy/message.py @@ -10,8 +10,6 @@ from plumpy.coordinator import Coordinator from plumpy.exceptions import PersistenceError, TaskRejectedError -from plumpy.exceptions import PersistenceError, TaskRejectedError - from . import loaders, persistence from .utils import PID_TYPE diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index c40347b2..7e82a9c3 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- """The main Process module""" +from __future__ import annotations + import abc import asyncio import concurrent.futures diff --git a/src/plumpy/rmq/communications.py b/src/plumpy/rmq/communications.py index b27b65a1..50927557 100644 --- a/src/plumpy/rmq/communications.py +++ b/src/plumpy/rmq/communications.py @@ -78,11 +78,11 @@ def converted(communicator: kiwipy.Communicator, *args: Any, **kwargs: Any) -> k return converted + T = TypeVar('T', bound=kiwipy.Communicator) -def wrap_communicator( - communicator: T, loop: Optional[asyncio.AbstractEventLoop] = None -) -> 'LoopCommunicator[T]': + +def wrap_communicator(communicator: T, loop: Optional[asyncio.AbstractEventLoop] = None) -> 'LoopCommunicator[T]': """ Wrap a communicator such that all callbacks made to any subscribers are scheduled on the given event loop. diff --git a/src/plumpy/rmq/coordinator.py b/src/plumpy/rmq/coordinator.py deleted file mode 100644 index c529b61c..00000000 --- a/src/plumpy/rmq/coordinator.py +++ /dev/null @@ -1,85 +0,0 @@ -# -*- coding: utf-8 -*- -from typing import Generic, TypeVar, final -import kiwipy -import concurrent.futures - -from plumpy.exceptions import CoordinatorConnectionError - -__all__ = ['RmqCoordinator'] - -U = TypeVar("U", bound=kiwipy.Communicator) - -@final -class RmqCoordinator(Generic[U]): - def __init__(self, comm: U): - self._comm = comm - - @property - def communicator(self) -> U: - """The inner communicator.""" - return self._comm - - # XXX: naming - `add_receiver_rpc` - def add_rpc_subscriber(self, subscriber, identifier=None): - return self._comm.add_rpc_subscriber(subscriber, identifier) - - # XXX: naming - `add_receiver_broadcast` - def add_broadcast_subscriber( - self, - subscriber, - subject_filters=None, - sender_filters=None, - identifier=None, - ): - subscriber = kiwipy.BroadcastFilter(subscriber) - - subject_filters = subject_filters or [] - sender_filters = sender_filters or [] - - for filter in subject_filters: - subscriber.add_subject_filter(filter) - for filter in sender_filters: - subscriber.add_sender_filter(filter) - - return self._comm.add_broadcast_subscriber(subscriber, identifier) - - # XXX: naming - `add_reciver_task` (can be combined with two above maybe??) - def add_task_subscriber(self, subscriber, identifier=None): - return self._comm.add_task_subscriber(subscriber, identifier) - - def remove_rpc_subscriber(self, identifier): - return self._comm.remove_rpc_subscriber(identifier) - - def remove_broadcast_subscriber(self, identifier): - return self._comm.remove_broadcast_subscriber(identifier) - - def remove_task_subscriber(self, identifier): - return self._comm.remove_task_subscriber(identifier) - - # XXX: naming - `send_to` - def rpc_send(self, recipient_id, msg): - return self._comm.rpc_send(recipient_id, msg) - - # XXX: naming - `broadcast` - def broadcast_send( - self, - body, - sender=None, - subject=None, - correlation_id=None, - ): - from aio_pika.exceptions import ChannelInvalidStateError, AMQPConnectionError - - try: - rsp = self._comm.broadcast_send(body, sender, subject, correlation_id) - except (ChannelInvalidStateError, AMQPConnectionError, concurrent.futures.TimeoutError) as exc: - raise CoordinatorConnectionError from exc - else: - return rsp - - # XXX: naming - `assign_task` (this may able to be combined with send_to) - def task_send(self, task, no_reply=False): - return self._comm.task_send(task, no_reply) - - def close(self): - self._comm.close() diff --git a/src/plumpy/rmq/futures.py b/src/plumpy/rmq/futures.py index 0ebe0d45..b0da02db 100644 --- a/src/plumpy/rmq/futures.py +++ b/src/plumpy/rmq/futures.py @@ -10,7 +10,7 @@ import kiwipy -__all__ = ['wrap_to_concurrent_future', 'unwrap_kiwi_future'] +__all__ = ['unwrap_kiwi_future', 'wrap_to_concurrent_future'] def _convert_future_exc(exc): @@ -112,6 +112,7 @@ def wrap_to_concurrent_future(future: asyncio.Future[Any]) -> kiwipy.Future: _chain_future(future, new_future) return new_future + # XXX: this required in aiida-core, see if really need this unwrap. def unwrap_kiwi_future(future: kiwipy.Future) -> kiwipy.Future: """ diff --git a/tests/rmq/__init__.py b/tests/rmq/__init__.py index 2845b1bb..3a3b9f67 100644 --- a/tests/rmq/__init__.py +++ b/tests/rmq/__init__.py @@ -1,14 +1,23 @@ # -*- coding: utf-8 -*- +from typing import Generic, TypeVar, final import kiwipy import concurrent.futures from plumpy.exceptions import CoordinatorConnectionError -class RmqCoordinator: - def __init__(self, comm: kiwipy.Communicator): +U = TypeVar('U', bound=kiwipy.Communicator) + +@final +class RmqCoordinator(Generic[U]): + def __init__(self, comm: U): self._comm = comm + @property + def communicator(self) -> U: + """The inner communicator.""" + return self._comm + # XXX: naming - `add_receiver_rpc` def add_rpc_subscriber(self, subscriber, identifier=None): return self._comm.add_rpc_subscriber(subscriber, identifier) @@ -18,9 +27,19 @@ def add_broadcast_subscriber( self, subscriber, subject_filters=None, + sender_filters=None, identifier=None, ): - subscriber = kiwipy.BroadcastFilter(subscriber, subject=subject_filters) + subscriber = kiwipy.BroadcastFilter(subscriber) + + subject_filters = subject_filters or [] + sender_filters = sender_filters or [] + + for filter in subject_filters: + subscriber.add_subject_filter(filter) + for filter in sender_filters: + subscriber.add_sender_filter(filter) + return self._comm.add_broadcast_subscriber(subscriber, identifier) # XXX: naming - `add_reciver_task` (can be combined with two above maybe??) From b577f5eb90a8f60d5adbebf7b85a4f45b48e345f Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Sun, 12 Jan 2025 00:22:23 +0100 Subject: [PATCH 17/22] Export plumpy.futures.Future --- src/plumpy/futures.py | 12 +----------- src/plumpy/processes.py | 2 ++ src/plumpy/rmq/__init__.py | 21 +++++++++++++++++---- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/plumpy/futures.py b/src/plumpy/futures.py index ed43389e..139c6069 100644 --- a/src/plumpy/futures.py +++ b/src/plumpy/futures.py @@ -9,7 +9,7 @@ import contextlib from typing import Any, Awaitable, Callable, Generator, Optional -__all__ = ['CancellableAction', 'capture_exceptions', 'create_task', 'create_task'] +__all__ = ['CancellableAction', 'Future', 'capture_exceptions', 'create_task', 'create_task'] class InvalidFutureError(Exception): @@ -78,14 +78,4 @@ def create_task(coro: Callable[[], Awaitable[Any]], loop: Optional[asyncio.Abstr """ loop = loop or asyncio.get_event_loop() - # future = loop.create_future() - # - # async def run_task() -> None: - # with capture_exceptions(future): - # res = await coro() - # future.set_result(res) - # - # asyncio.run_coroutine_threadsafe(run_task(), loop) - # return future - return asyncio.wrap_future(asyncio.run_coroutine_threadsafe(coro(), loop)) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 7e82a9c3..4c048d9c 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -329,6 +329,7 @@ def init(self) -> None: try: # filter out state change broadcasts + # XXX: remove dep on kiwipy subscriber = kiwipy.BroadcastFilter(self.broadcast_receive, subject=re.compile(r'^(?!state_changed).*')) identifier = self._coordinator.add_broadcast_subscriber(subscriber, identifier=str(self.pid)) # identifier = self._coordinator.add_broadcast_subscriber( @@ -1332,6 +1333,7 @@ async def step(self) -> None: self._stepping = True next_state = None try: + # XXX: debug log when need to step to next state next_state = await self._run_task(self._state.execute) except process_states.Interruption as exception: # If the interruption was caused by a call to a Process method then there should diff --git a/src/plumpy/rmq/__init__.py b/src/plumpy/rmq/__init__.py index c44c5a2e..a046d229 100644 --- a/src/plumpy/rmq/__init__.py +++ b/src/plumpy/rmq/__init__.py @@ -1,7 +1,20 @@ # -*- coding: utf-8 -*- # mypy: disable-error-code=name-defined -from .communications import * -from .futures import * -from .process_control import * +from .communications import Communicator, DeliveryFailed, RemoteException, TaskRejected, wrap_communicator +from .futures import unwrap_kiwi_future, wrap_to_concurrent_future +from .process_control import RemoteProcessController, RemoteProcessThreadController -__all__ = communications.__all__ + futures.__all__ + process_control.__all__ +__all__ = [ + # communications + 'Communicator', + 'DeliveryFailed', + 'RemoteException', + # process_control + 'RemoteProcessController', + 'RemoteProcessThreadController', + 'TaskRejected', + # futures + 'unwrap_kiwi_future', + 'wrap_communicator', + 'wrap_to_concurrent_future', +] From b0877dbec67a490fcda66321b370fe75af24b1c0 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Fri, 21 Feb 2025 20:10:21 +0100 Subject: [PATCH 18/22] Remove first unnecessary `_comm` argument to subscriber --- src/plumpy/broadcast_filter.py | 58 ++++++++++++++++++++++++++++++++++ src/plumpy/coordinator.py | 6 ++-- src/plumpy/message.py | 2 +- src/plumpy/processes.py | 39 +++++++++++------------ tests/rmq/__init__.py | 39 +++++++++++++---------- tests/rmq/test_communicator.py | 33 ++++++++++++------- tests/test_processes.py | 2 +- tests/utils.py | 2 +- 8 files changed, 125 insertions(+), 56 deletions(-) create mode 100644 src/plumpy/broadcast_filter.py diff --git a/src/plumpy/broadcast_filter.py b/src/plumpy/broadcast_filter.py new file mode 100644 index 00000000..61b27095 --- /dev/null +++ b/src/plumpy/broadcast_filter.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +import re +import typing + + +class BroadcastFilter: + """A filter that can be used to limit the subjects and/or senders that will be received""" + + def __init__(self, subscriber, subject=None, sender=None): + self._subscriber = subscriber + self._subject_filters = [] + self._sender_filters = [] + if subject is not None: + self.add_subject_filter(subject) + if sender is not None: + self.add_sender_filter(sender) + + @property + def __name__(self): + return 'BroadcastFilter' + + def __call__(self, body, sender=None, subject=None, correlation_id=None): + if self.is_filtered(sender, subject): + return None + return self._subscriber(body, sender, subject, correlation_id) + + def is_filtered(self, sender, subject) -> bool: + if subject is not None and self._subject_filters and not any(check(subject) for check in self._subject_filters): + return True + + if sender is not None and self._sender_filters and not any(check(sender) for check in self._sender_filters): + return True + + return False + + def add_subject_filter(self, subject_filter): + self._subject_filters.append(self._ensure_filter(subject_filter)) + + def add_sender_filter(self, sender_filter): + self._sender_filters.append(self._ensure_filter(sender_filter)) + + @classmethod + def _ensure_filter(cls, filter_value): + if isinstance(filter_value, str): + return re.compile(filter_value.replace('.', '[.]').replace('*', '.*')).match + if isinstance(filter_value, typing.Pattern): # pylint: disable=isinstance-second-argument-not-valid-type + return filter_value.match + + return lambda val: val == filter_value + + @classmethod + def _make_regex(cls, filter_str): + """ + :param filter_str: The filter string + :type filter_str: str + :return: The regular expression object + """ + return re.compile(filter_str.replace('.', '[.]')) diff --git a/src/plumpy/coordinator.py b/src/plumpy/coordinator.py index 702ea5f5..6905c6a7 100644 --- a/src/plumpy/coordinator.py +++ b/src/plumpy/coordinator.py @@ -8,11 +8,11 @@ ID_TYPE = Hashable Subscriber = Callable[..., Any] # RPC subscriber params: communicator, msg - RpcSubscriber = Callable[['Coordinator', Any], Any] + RpcSubscriber = Callable[[Any], Any] # Task subscriber params: communicator, task - TaskSubscriber = Callable[['Coordinator', Any], Any] + TaskSubscriber = Callable[[Any], Any] # Broadcast subscribers params: communicator, body, sender, subject, correlation id - BroadcastSubscriber = Callable[['Coordinator', Any, Any, Any, ID_TYPE], Any] + BroadcastSubscriber = Callable[[Any, Any, Any, ID_TYPE], Any] class Coordinator(Protocol): diff --git a/src/plumpy/message.py b/src/plumpy/message.py index 009f1b26..098277e1 100644 --- a/src/plumpy/message.py +++ b/src/plumpy/message.py @@ -218,7 +218,7 @@ def __init__( else: self._loader = loaders.get_object_loader() - async def __call__(self, coordinator: Coordinator, task: Dict[str, Any]) -> Union[PID_TYPE, Any]: + async def call(self, task: Dict[str, Any]) -> Union[PID_TYPE, Any]: """ Receive a task. :param task: The task message diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 4c048d9c..e2c1a673 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -36,6 +36,7 @@ import kiwipy +from plumpy.broadcast_filter import BroadcastFilter from plumpy.coordinator import Coordinator try: @@ -329,12 +330,9 @@ def init(self) -> None: try: # filter out state change broadcasts - # XXX: remove dep on kiwipy - subscriber = kiwipy.BroadcastFilter(self.broadcast_receive, subject=re.compile(r'^(?!state_changed).*')) + subscriber = BroadcastFilter(self.broadcast_receive, subject=re.compile(r'^(?!state_changed).*')) identifier = self._coordinator.add_broadcast_subscriber(subscriber, identifier=str(self.pid)) - # identifier = self._coordinator.add_broadcast_subscriber( - # subscriber, subject_filters=[re.compile(r'^(?!state_changed).*')], identifier=str(self.pid) - # ) + self.add_cleanup(functools.partial(self._coordinator.remove_broadcast_subscriber, identifier)) except concurrent.futures.TimeoutError: self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid) @@ -945,7 +943,7 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non # region Communication - def message_receive(self, _comm: Coordinator, msg: MessageType) -> Any: + def message_receive(self, msg: MessageType) -> Any: """ Coroutine called when the process receives a message from the communicator @@ -953,12 +951,12 @@ def message_receive(self, _comm: Coordinator, msg: MessageType) -> Any: :param msg: the message :return: the outcome of processing the message, the return value will be sent back as a response to the sender """ - self.logger.debug( - "Process<%s>: received RPC message with communicator '%s': %r", - self.pid, - _comm, - msg, - ) + # self.logger.debug( + # "Process<%s>: received RPC message with communicator '%s': %r", + # self.pid, + # _comm, + # msg, + # ) intent = msg[message.INTENT_KEY] @@ -977,22 +975,21 @@ def message_receive(self, _comm: Coordinator, msg: MessageType) -> Any: raise RuntimeError('Unknown intent') def broadcast_receive( - self, _comm: Coordinator, msg: MessageType, sender: Any, subject: Any, correlation_id: Any + self, msg: MessageType, sender: Any, subject: Any, correlation_id: Any ) -> Optional[concurrent.futures.Future]: """ Coroutine called when the process receives a message from the communicator - :param _comm: the communicator that sent the message :param msg: the message """ - self.logger.debug( - "Process<%s>: received broadcast message '%s' with communicator '%s': %r", - self.pid, - subject, - _comm, - msg, - ) + # self.logger.debug( + # "Process<%s>: received broadcast message '%s' with communicator '%s': %r", + # self.pid, + # subject, + # _comm, + # msg, + # ) # If we get a message we recognise then action it, otherwise ignore fn = None if subject == message.Intent.PLAY: diff --git a/tests/rmq/__init__.py b/tests/rmq/__init__.py index 3a3b9f67..70205984 100644 --- a/tests/rmq/__init__.py +++ b/tests/rmq/__init__.py @@ -1,13 +1,19 @@ # -*- coding: utf-8 -*- -from typing import Generic, TypeVar, final +from __future__ import annotations +from re import Pattern +from typing import TYPE_CHECKING, Generic, Hashable, TypeVar, final import kiwipy import concurrent.futures from plumpy.exceptions import CoordinatorConnectionError +if TYPE_CHECKING: + ID_TYPE = Hashable + BroadcastSubscriber = Callable[[Any, Any, Any, ID_TYPE], Any] U = TypeVar('U', bound=kiwipy.Communicator) + @final class RmqCoordinator(Generic[U]): def __init__(self, comm: U): @@ -20,31 +26,30 @@ def communicator(self) -> U: # XXX: naming - `add_receiver_rpc` def add_rpc_subscriber(self, subscriber, identifier=None): - return self._comm.add_rpc_subscriber(subscriber, identifier) + def _subscriber(_, *args, **kwargs): + return subscriber(*args, **kwargs) + + return self._comm.add_rpc_subscriber(_subscriber, identifier) # XXX: naming - `add_receiver_broadcast` def add_broadcast_subscriber( self, - subscriber, - subject_filters=None, - sender_filters=None, - identifier=None, + subscriber: 'BroadcastSubscriber', + subject_filters: list[Hashable | Pattern[str]] | None = None, + sender_filters: list[Hashable | Pattern[str]] | None = None, + identifier: 'ID_TYPE | None' = None, ): - subscriber = kiwipy.BroadcastFilter(subscriber) - - subject_filters = subject_filters or [] - sender_filters = sender_filters or [] + def _subscriber(_, *args, **kwargs): + return subscriber(*args, **kwargs) - for filter in subject_filters: - subscriber.add_subject_filter(filter) - for filter in sender_filters: - subscriber.add_sender_filter(filter) - - return self._comm.add_broadcast_subscriber(subscriber, identifier) + return self._comm.add_broadcast_subscriber(_subscriber, identifier) # XXX: naming - `add_reciver_task` (can be combined with two above maybe??) def add_task_subscriber(self, subscriber, identifier=None): - return self._comm.add_task_subscriber(subscriber, identifier) + async def _subscriber(_comm, *args, **kwargs): + return await subscriber(*args, **kwargs) + + return self._comm.add_task_subscriber(_subscriber, identifier) def remove_rpc_subscriber(self, identifier): return self._comm.remove_rpc_subscriber(identifier) diff --git a/tests/rmq/test_communicator.py b/tests/rmq/test_communicator.py index e0baaea5..139f6434 100644 --- a/tests/rmq/test_communicator.py +++ b/tests/rmq/test_communicator.py @@ -13,6 +13,7 @@ from kiwipy.rmq import RmqThreadCommunicator import plumpy +from plumpy.broadcast_filter import BroadcastFilter from plumpy.coordinator import Coordinator from plumpy.rmq import communications, process_control @@ -72,7 +73,7 @@ async def test_broadcast(self, _coordinator): loop = asyncio.get_event_loop() - def get_broadcast(_comm, body, sender, subject, correlation_id): + def get_broadcast(body, sender, subject, correlation_id): assert loop is asyncio.get_event_loop() broadcast_future.set_result( @@ -89,13 +90,13 @@ def get_broadcast(_comm, body, sender, subject, correlation_id): async def test_broadcast_filter(self, _coordinator: Coordinator): broadcast_future = asyncio.Future() - def ignore_broadcast(_comm, body, sender, subject, correlation_id): + def ignore_broadcast(body, sender, subject, correlation_id): broadcast_future.set_exception(AssertionError('broadcast received')) - def get_broadcast(_comm, body, sender, subject, correlation_id): + def get_broadcast(body, sender, subject, correlation_id): broadcast_future.set_result(True) - _coordinator.add_broadcast_subscriber(ignore_broadcast, subject_filters=['other']) + _coordinator.add_broadcast_subscriber(BroadcastFilter(ignore_broadcast, subject='other')) _coordinator.add_broadcast_subscriber(get_broadcast) _coordinator.broadcast_send(**{'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420}) @@ -109,7 +110,7 @@ async def test_rpc(self, _coordinator): loop = asyncio.get_event_loop() - def get_rpc(_comm, msg): + def get_rpc(msg): assert loop is asyncio.get_event_loop() rpc_future.set_result(msg) @@ -126,13 +127,15 @@ async def test_task(self, _coordinator): loop = asyncio.get_event_loop() - def get_task(_comm, msg): + def get_task(msg): assert loop is asyncio.get_event_loop() task_future.set_result(msg) _coordinator.add_task_subscriber(get_task) _coordinator.task_send(TASK) + # TODO: Error in the event loop log although the test pass + # The issue exist before rmq-out refactoring. result = await task_future assert result == TASK @@ -142,7 +145,8 @@ class TestTaskActions: async def test_launch(self, _coordinator, async_controller, persister): # Let the process run to the end loop = asyncio.get_event_loop() - _coordinator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) + launcher = plumpy.ProcessLauncher(loop, persister=persister) + _coordinator.add_task_subscriber(launcher.call) result = await async_controller.launch_process(utils.DummyProcess) # Check that we got a result assert result == utils.DummyProcess.EXPECTED_OUTPUTS @@ -151,7 +155,8 @@ async def test_launch(self, _coordinator, async_controller, persister): async def test_launch_nowait(self, _coordinator, async_controller, persister): """Testing launching but don't wait, just get the pid""" loop = asyncio.get_event_loop() - _coordinator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) + launcher = plumpy.ProcessLauncher(loop, persister=persister) + _coordinator.add_task_subscriber(launcher.call) pid = await async_controller.launch_process(utils.DummyProcess, nowait=True) assert isinstance(pid, uuid.UUID) @@ -159,7 +164,8 @@ async def test_launch_nowait(self, _coordinator, async_controller, persister): async def test_execute_action(self, _coordinator, async_controller, persister): """Test the process execute action""" loop = asyncio.get_event_loop() - _coordinator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) + launcher = plumpy.ProcessLauncher(loop, persister=persister) + _coordinator.add_task_subscriber(launcher.call) result = await async_controller.execute_process(utils.DummyProcessWithOutput) assert utils.DummyProcessWithOutput.EXPECTED_OUTPUTS == result @@ -167,7 +173,8 @@ async def test_execute_action(self, _coordinator, async_controller, persister): async def test_execute_action_nowait(self, _coordinator, async_controller, persister): """Test the process execute action""" loop = asyncio.get_event_loop() - _coordinator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) + launcher = plumpy.ProcessLauncher(loop, persister=persister) + _coordinator.add_task_subscriber(launcher.call) pid = await async_controller.execute_process(utils.DummyProcessWithOutput, nowait=True) assert isinstance(pid, uuid.UUID) @@ -175,7 +182,8 @@ async def test_execute_action_nowait(self, _coordinator, async_controller, persi async def test_launch_many(self, _coordinator, async_controller, persister): """Test launching multiple processes""" loop = asyncio.get_event_loop() - _coordinator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) + launcher = plumpy.ProcessLauncher(loop, persister=persister) + _coordinator.add_task_subscriber(launcher.call) num_to_launch = 10 launch_futures = [] @@ -191,7 +199,8 @@ async def test_launch_many(self, _coordinator, async_controller, persister): async def test_continue(self, _coordinator, async_controller, persister): """Test continuing a saved process""" loop = asyncio.get_event_loop() - _coordinator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) + launcher = plumpy.ProcessLauncher(loop, persister=persister) + _coordinator.add_task_subscriber(launcher.call) process = utils.DummyProcessWithOutput() persister.save_checkpoint(process) pid = process.pid diff --git a/tests/test_processes.py b/tests/test_processes.py index a05d09a3..0373b037 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -1068,7 +1068,7 @@ def test_broadcast(self): messages = [] - def on_broadcast_receive(_comm, body, sender, subject, correlation_id): + def on_broadcast_receive(body, sender, subject, correlation_id): messages.append({'body': body, 'subject': subject, 'sender': sender, 'correlation_id': correlation_id}) coordinator.add_broadcast_subscriber(on_broadcast_receive) diff --git a/tests/utils.py b/tests/utils.py index 3d4458f4..323a3282 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -133,7 +133,7 @@ def rpc_send(self, recipient_id, msg): def broadcast_send(self, body, sender=None, subject=None, correlation_id=None): self._ensure_open() for subscriber in self._broadcast_subscribers.values(): - subscriber(self, body=body, sender=sender, subject=subject, correlation_id=correlation_id) + subscriber(body=body, sender=sender, subject=subject, correlation_id=correlation_id) return True def _ensure_open(self): From 4b0267cc711871be10a00fd95a90d4ccf5a68449 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Fri, 21 Feb 2025 23:11:48 +0100 Subject: [PATCH 19/22] Protocol fulfill for RemoteProcessThreadController --- src/plumpy/controller.py | 68 ++++++++++++++++++++----------- src/plumpy/rmq/process_control.py | 15 +++---- 2 files changed, 50 insertions(+), 33 deletions(-) diff --git a/src/plumpy/controller.py b/src/plumpy/controller.py index dcf203dc..9f2793a0 100644 --- a/src/plumpy/controller.py +++ b/src/plumpy/controller.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Any, Protocol +from typing import Any, Hashable, Optional, Protocol, Union from plumpy import loaders from plumpy.message import MessageType @@ -26,7 +26,7 @@ def get_status(self, pid: 'PID_TYPE') -> ProcessStatus: """ ... - def pause_process(self, pid: 'PID_TYPE', msg: Any | None = None) -> ProcessResult: + def pause_process(self, pid: 'PID_TYPE', msg: str | None = None) -> ProcessResult: """ Pause the process @@ -36,18 +36,27 @@ def pause_process(self, pid: 'PID_TYPE', msg: Any | None = None) -> ProcessResul """ ... - def play_process(self, pid: 'PID_TYPE') -> ProcessResult: + def pause_all(self, msg_text: str | None) -> None: + """Pause all processes that are subscribed to the same coordinator + + :param msg_text: an optional pause message text """ - Play the process + ... + + def play_process(self, pid: 'PID_TYPE') -> ProcessResult: + """Play the process :param pid: the pid of the process to play :return: True if played, False otherwise """ ... - def kill_process(self, pid: 'PID_TYPE', msg: MessageType | None = None) -> ProcessResult: + def play_all(self) -> None: + """Play all processes that are subscribed to the same coordinator """ - Kill the process + + def kill_process(self, pid: 'PID_TYPE', msg_text: str | None = None) -> Any: + """Kill the process :param pid: the pid of the process to kill :param msg: optional kill message @@ -55,11 +64,24 @@ def kill_process(self, pid: 'PID_TYPE', msg: MessageType | None = None) -> Proce """ ... - def continue_process( - self, pid: 'PID_TYPE', tag: str | None = None, nowait: bool = False, no_reply: bool = False - ) -> ProcessResult | None: + def kill_all(self, msg_text: Optional[str]) -> None: + """Kill all processes that are subscribed to the same coordinator + + :param msg: an optional pause message """ - Continue the process + ... + + def notify_msg(self, msg: MessageType, sender: Hashable | None = None, subject: str | None = None) -> None: + """ + Notify all processes by broadcasting of a msg + + :param msg: an optional pause message + """ + + def continue_process( + self, pid: 'PID_TYPE', tag: Optional[str] = None, nowait: bool = False, no_reply: bool = False + ) -> Union[None, PID_TYPE, ProcessResult]: + """Continue the process :param _communicator: the communicator :param pid: the pid of the process to continue @@ -67,18 +89,17 @@ def continue_process( """ ... - async def launch_process( + def launch_process( self, process_class: str, - init_args: Sequence[Any] | None = None, - init_kwargs: dict[str, Any] | None = None, + init_args: Optional[Sequence[Any]] = None, + init_kwargs: Optional[dict[str, Any]] = None, persist: bool = False, - loader: loaders.ObjectLoader | None = None, + loader: Optional[loaders.ObjectLoader] = None, nowait: bool = False, no_reply: bool = False, - ) -> ProcessResult: - """ - Launch a process given the class and constructor arguments + ) -> Union[None, PID_TYPE, ProcessResult]: + """Launch a process given the class and constructor arguments :param process_class: the class of the process to launch :param init_args: the constructor positional arguments @@ -91,17 +112,16 @@ async def launch_process( """ ... - async def execute_process( + def execute_process( self, process_class: str, - init_args: Sequence[Any] | None = None, - init_kwargs: dict[str, Any] | None = None, - loader: loaders.ObjectLoader | None = None, + init_args: Optional[Sequence[Any]] = None, + init_kwargs: Optional[dict[str, Any]] = None, + loader: Optional[loaders.ObjectLoader] = None, nowait: bool = False, no_reply: bool = False, - ) -> ProcessResult: - """ - Execute a process. This call will first send a create task and then a continue task over + ) -> Union[None, PID_TYPE, ProcessResult]: + """Execute a process. This call will first send a create task and then a continue task over the communicator. This means that if communicator messages are durable then the process will run until the end even if this interpreter instance ceases to exist. diff --git a/src/plumpy/rmq/process_control.py b/src/plumpy/rmq/process_control.py index 0caf1d7a..02eb8853 100644 --- a/src/plumpy/rmq/process_control.py +++ b/src/plumpy/rmq/process_control.py @@ -190,7 +190,6 @@ async def execute_process( return result -# FIXME: the class not fit typing of ProcessController protocol class RemoteProcessThreadController: """ A class that can be used to control and launch remote processes @@ -213,9 +212,8 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future: """ return self._coordinator.rpc_send(pid, MessageBuilder.status()) - def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future: - """ - Pause the process + def pause_process(self, pid: 'PID_TYPE', msg_text: str | None = None) -> kiwipy.Future: + """Pause the process :param pid: the pid of the process to pause :param msg: optional pause message @@ -226,11 +224,10 @@ def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwi return self._coordinator.rpc_send(pid, msg) - def pause_all(self, msg_text: Optional[str]) -> None: - """ - Pause all processes that are subscribed to the same coordinator + def pause_all(self, msg_text: str | None) -> None: + """Pause all processes that are subscribed to the same coordinator - :param msg: an optional pause message + :param msg_text: an optional pause message text """ msg = MessageBuilder.pause(text=msg_text) self._coordinator.broadcast_send(msg, subject=Intent.PAUSE) @@ -251,7 +248,7 @@ def play_all(self) -> None: """ self._coordinator.broadcast_send(None, subject=Intent.PLAY) - def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future: + def kill_process(self, pid: 'PID_TYPE', msg_text: str | None = None) -> kiwipy.Future: """ Kill the process From c4c50e0e3207bae8b76d4b9b0687063f0e1557d6 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Fri, 21 Feb 2025 23:45:01 +0100 Subject: [PATCH 20/22] Rename interfaces of Coordinator subscriber -> receiver to distinguish interfaces from RMQ communicator. --- src/plumpy/coordinator.py | 39 ++++++----- src/plumpy/processes.py | 10 +-- tests/rmq/__init__.py | 60 +++++++++-------- tests/rmq/test_communications.py | 43 ++++++------ tests/rmq/test_communicator.py | 22 +++---- tests/rmq/test_process_control.py | 2 +- tests/test_processes.py | 2 +- tests/utils.py | 105 +++++++++++++++++------------- 8 files changed, 153 insertions(+), 130 deletions(-) diff --git a/src/plumpy/coordinator.py b/src/plumpy/coordinator.py index 6905c6a7..e647961e 100644 --- a/src/plumpy/coordinator.py +++ b/src/plumpy/coordinator.py @@ -1,43 +1,42 @@ # -*- coding: utf-8 -*- from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Hashable, Pattern, Protocol +from typing import TYPE_CHECKING, Any, Callable, Hashable, Protocol +from re import Pattern if TYPE_CHECKING: - # identifiers for subscribers ID_TYPE = Hashable - Subscriber = Callable[..., Any] - # RPC subscriber params: communicator, msg - RpcSubscriber = Callable[[Any], Any] - # Task subscriber params: communicator, task - TaskSubscriber = Callable[[Any], Any] - # Broadcast subscribers params: communicator, body, sender, subject, correlation id - BroadcastSubscriber = Callable[[Any, Any, Any, ID_TYPE], Any] + Receiver = Callable[..., Any] class Coordinator(Protocol): - # XXX: naming - 'add_message_handler' - def add_rpc_subscriber(self, subscriber: 'RpcSubscriber', identifier: 'ID_TYPE | None' = None) -> Any: ... + def hook_rpc_receiver( + self, + receiver: 'Receiver', + identifier: 'ID_TYPE | None' = None, + ) -> Any: ... - # XXX: naming - 'add_broadcast_handler' - def add_broadcast_subscriber( + def hook_broadcast_receiver( self, - subscriber: 'BroadcastSubscriber', + receiver: 'Receiver', subject_filters: list[Hashable | Pattern[str]] | None = None, sender_filters: list[Hashable | Pattern[str]] | None = None, identifier: 'ID_TYPE | None' = None, ) -> Any: ... - # XXX: naming - absorbed into 'add_message_handler' - def add_task_subscriber(self, subscriber: 'TaskSubscriber', identifier: 'ID_TYPE | None' = None) -> 'ID_TYPE': ... + def hook_task_receiver( + self, + receiver: 'Receiver', + identifier: 'ID_TYPE | None' = None, + ) -> 'ID_TYPE': ... - def remove_rpc_subscriber(self, identifier: 'ID_TYPE | None') -> None: ... + def unhook_rpc_receiver(self, identifier: 'ID_TYPE | None') -> None: ... - def remove_broadcast_subscriber(self, identifier: 'ID_TYPE | None') -> None: ... + def unhook_broadcast_receiver(self, identifier: 'ID_TYPE | None') -> None: ... - def remove_task_subscriber(self, identifier: 'ID_TYPE') -> None: ... + def unhook_task_receiver(self, identifier: 'ID_TYPE') -> None: ... - def rpc_send(self, recipient_id: Hashable, msg: Any) -> Any: ... + def rpc_send(self, recipient_id: Hashable, msg: Any,) -> Any: ... def broadcast_send( self, diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index e2c1a673..75737574 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -322,18 +322,18 @@ def init(self) -> None: if self._coordinator is not None: try: - identifier = self._coordinator.add_rpc_subscriber(self.message_receive, identifier=str(self.pid)) - self.add_cleanup(functools.partial(self._coordinator.remove_rpc_subscriber, identifier)) + identifier = self._coordinator.hook_rpc_receiver(self.message_receive, identifier=str(self.pid)) + self.add_cleanup(functools.partial(self._coordinator.unhook_rpc_receiver, identifier)) except concurrent.futures.TimeoutError: self.logger.exception('Process<%s>: failed to register as an RPC subscriber', self.pid) - # XXX: handle duplicate subscribing here: see aiida-core test_duplicate_subscriber_identifier. + # XXX: handle duplicate subscribing here: see aiida-core test_duplicate_subscriber_identifier. try: # filter out state change broadcasts subscriber = BroadcastFilter(self.broadcast_receive, subject=re.compile(r'^(?!state_changed).*')) - identifier = self._coordinator.add_broadcast_subscriber(subscriber, identifier=str(self.pid)) + identifier = self._coordinator.hook_broadcast_receiver(subscriber, identifier=str(self.pid)) - self.add_cleanup(functools.partial(self._coordinator.remove_broadcast_subscriber, identifier)) + self.add_cleanup(functools.partial(self._coordinator.unhook_broadcast_receiver, identifier)) except concurrent.futures.TimeoutError: self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid) diff --git a/tests/rmq/__init__.py b/tests/rmq/__init__.py index 70205984..91af1549 100644 --- a/tests/rmq/__init__.py +++ b/tests/rmq/__init__.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import annotations from re import Pattern -from typing import TYPE_CHECKING, Generic, Hashable, TypeVar, final +from typing import TYPE_CHECKING, Any, Callable, Generic, Hashable, TypeVar, final import kiwipy import concurrent.futures @@ -9,7 +9,7 @@ if TYPE_CHECKING: ID_TYPE = Hashable - BroadcastSubscriber = Callable[[Any, Any, Any, ID_TYPE], Any] + Receiver = Callable[..., Any] U = TypeVar('U', bound=kiwipy.Communicator) @@ -24,54 +24,61 @@ def communicator(self) -> U: """The inner communicator.""" return self._comm - # XXX: naming - `add_receiver_rpc` - def add_rpc_subscriber(self, subscriber, identifier=None): + def hook_rpc_receiver( + self, + receiver: 'Receiver', + identifier: 'ID_TYPE | None' = None, + ) -> Any: def _subscriber(_, *args, **kwargs): - return subscriber(*args, **kwargs) + return receiver(*args, **kwargs) return self._comm.add_rpc_subscriber(_subscriber, identifier) - # XXX: naming - `add_receiver_broadcast` - def add_broadcast_subscriber( + def hook_broadcast_receiver( self, - subscriber: 'BroadcastSubscriber', + receiver: 'Receiver', subject_filters: list[Hashable | Pattern[str]] | None = None, sender_filters: list[Hashable | Pattern[str]] | None = None, identifier: 'ID_TYPE | None' = None, - ): + ) -> Any: def _subscriber(_, *args, **kwargs): - return subscriber(*args, **kwargs) + return receiver(*args, **kwargs) return self._comm.add_broadcast_subscriber(_subscriber, identifier) - # XXX: naming - `add_reciver_task` (can be combined with two above maybe??) - def add_task_subscriber(self, subscriber, identifier=None): + def hook_task_receiver( + self, + receiver: 'Receiver', + identifier: 'ID_TYPE | None' = None, + ) -> 'ID_TYPE': async def _subscriber(_comm, *args, **kwargs): - return await subscriber(*args, **kwargs) + return await receiver(*args, **kwargs) return self._comm.add_task_subscriber(_subscriber, identifier) - def remove_rpc_subscriber(self, identifier): + def unhook_rpc_receiver(self, identifier: 'ID_TYPE | None') -> None: return self._comm.remove_rpc_subscriber(identifier) - def remove_broadcast_subscriber(self, identifier): + def unhook_broadcast_receiver(self, identifier: 'ID_TYPE | None') -> None: return self._comm.remove_broadcast_subscriber(identifier) - def remove_task_subscriber(self, identifier): + def unhook_task_receiver(self, identifier: 'ID_TYPE') -> None: return self._comm.remove_task_subscriber(identifier) - # XXX: naming - `send_to` - def rpc_send(self, recipient_id, msg): + def rpc_send( + self, + recipient_id: Hashable, + msg: Any, + ) -> Any: return self._comm.rpc_send(recipient_id, msg) - # XXX: naming - `broadcast` def broadcast_send( self, - body, - sender=None, - subject=None, - correlation_id=None, - ): + body: Any | None, + sender: 'ID_TYPE | None' = None, + subject: str | None = None, + correlation_id: 'ID_TYPE | None' = None, + ) -> Any: from aio_pika.exceptions import ChannelInvalidStateError, AMQPConnectionError try: @@ -81,9 +88,8 @@ def broadcast_send( else: return rsp - # XXX: naming - `assign_task` (this may able to be combined with send_to) - def task_send(self, task, no_reply=False): + def task_send(self, task: Any, no_reply: bool = False) -> Any: return self._comm.task_send(task, no_reply) - def close(self): + def close(self) -> None: self._comm.close() diff --git a/tests/rmq/test_communications.py b/tests/rmq/test_communications.py index e45994b2..92a1dd3f 100644 --- a/tests/rmq/test_communications.py +++ b/tests/rmq/test_communications.py @@ -31,7 +31,7 @@ def broadcast_send(self, body, sender=None, subject=None, correlation_id=None): @pytest.fixture -def subscriber(): +def receiver_fn(): """Return an instance of mocked `Subscriber`.""" class Subscriber: @@ -43,40 +43,41 @@ def __call__(self): return Subscriber() -def test_add_rpc_subscriber(_coordinator, subscriber): - """Test the `LoopCommunicator.add_rpc_subscriber` method.""" - assert _coordinator.add_rpc_subscriber(subscriber) is not None +def test_hook_rpc_receiver(_coordinator, receiver_fn): + """Test the `LoopCommunicator.add_rpc_receiver` method.""" + assert _coordinator.hook_rpc_receiver(receiver_fn) is not None identifier = 'identifier' - assert _coordinator.add_rpc_subscriber(subscriber, identifier) == identifier + assert _coordinator.hook_rpc_receiver(receiver_fn, identifier) == identifier -def test_remove_rpc_subscriber(_coordinator, subscriber): +def test_unhook_rpc_receiver(_coordinator, receiver_fn): """Test the `LoopCommunicator.remove_rpc_subscriber` method.""" - identifier = _coordinator.add_rpc_subscriber(subscriber) - _coordinator.remove_rpc_subscriber(identifier) + identifier = _coordinator.hook_rpc_receiver(receiver_fn) + _coordinator.unhook_rpc_receiver(identifier) -def test_add_broadcast_subscriber(_coordinator, subscriber): - """Test the `LoopCommunicator.add_broadcast_subscriber` method.""" - assert _coordinator.add_broadcast_subscriber(subscriber) is not None +def test_hook_broadcast_receiver(_coordinator, receiver_fn): + """Test the coordinator hook_broadcast_receiver which calls + `LoopCommunicator.add_broadcast_subscriber` method.""" + assert _coordinator.hook_broadcast_receiver(receiver_fn) is not None identifier = 'identifier' - assert _coordinator.add_broadcast_subscriber(subscriber, identifier=identifier) == identifier + assert _coordinator.hook_broadcast_receiver(receiver_fn, identifier=identifier) == identifier -def test_remove_broadcast_subscriber(_coordinator, subscriber): +def test_unhook_broadcast_receiver(_coordinator, receiver_fn): """Test the `LoopCommunicator.remove_broadcast_subscriber` method.""" - identifier = _coordinator.add_broadcast_subscriber(subscriber) - _coordinator.remove_broadcast_subscriber(identifier) + identifier = _coordinator.hook_broadcast_receiver(receiver_fn) + _coordinator.unhook_broadcast_receiver(identifier) -def test_add_task_subscriber(_coordinator, subscriber): - """Test the `LoopCommunicator.add_task_subscriber` method.""" - assert _coordinator.add_task_subscriber(subscriber) is not None +def test_hook_task_receiver(_coordinator, receiver_fn): + """Test the hook_task_receiver calls `LoopCommunicator.add_task_subscriber` method.""" + assert _coordinator.hook_task_receiver(receiver_fn) is not None -def test_remove_task_subscriber(_coordinator, subscriber): +def test_unhook_task_receiver(_coordinator, receiver_fn): """Test the `LoopCommunicator.remove_task_subscriber` method.""" - identifier = _coordinator.add_task_subscriber(subscriber) - _coordinator.remove_task_subscriber(identifier) + identifier = _coordinator.hook_task_receiver(receiver_fn) + _coordinator.unhook_task_receiver(identifier) diff --git a/tests/rmq/test_communicator.py b/tests/rmq/test_communicator.py index 139f6434..480e820b 100644 --- a/tests/rmq/test_communicator.py +++ b/tests/rmq/test_communicator.py @@ -80,7 +80,7 @@ def get_broadcast(body, sender, subject, correlation_id): {'body': body, 'sender': sender, 'subject': subject, 'correlation_id': correlation_id} ) - _coordinator.add_broadcast_subscriber(get_broadcast) + _coordinator.hook_broadcast_receiver(get_broadcast) _coordinator.broadcast_send(**BROADCAST) result = await broadcast_future @@ -96,8 +96,8 @@ def ignore_broadcast(body, sender, subject, correlation_id): def get_broadcast(body, sender, subject, correlation_id): broadcast_future.set_result(True) - _coordinator.add_broadcast_subscriber(BroadcastFilter(ignore_broadcast, subject='other')) - _coordinator.add_broadcast_subscriber(get_broadcast) + _coordinator.hook_broadcast_receiver(BroadcastFilter(ignore_broadcast, subject='other')) + _coordinator.hook_broadcast_receiver(get_broadcast) _coordinator.broadcast_send(**{'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420}) result = await broadcast_future @@ -114,7 +114,7 @@ def get_rpc(msg): assert loop is asyncio.get_event_loop() rpc_future.set_result(msg) - _coordinator.add_rpc_subscriber(get_rpc, 'rpc') + _coordinator.hook_rpc_receiver(get_rpc, 'rpc') _coordinator.rpc_send('rpc', MSG) result = await rpc_future @@ -131,7 +131,7 @@ def get_task(msg): assert loop is asyncio.get_event_loop() task_future.set_result(msg) - _coordinator.add_task_subscriber(get_task) + _coordinator.hook_task_receiver(get_task) _coordinator.task_send(TASK) # TODO: Error in the event loop log although the test pass @@ -146,7 +146,7 @@ async def test_launch(self, _coordinator, async_controller, persister): # Let the process run to the end loop = asyncio.get_event_loop() launcher = plumpy.ProcessLauncher(loop, persister=persister) - _coordinator.add_task_subscriber(launcher.call) + _coordinator.hook_task_receiver(launcher.call) result = await async_controller.launch_process(utils.DummyProcess) # Check that we got a result assert result == utils.DummyProcess.EXPECTED_OUTPUTS @@ -156,7 +156,7 @@ async def test_launch_nowait(self, _coordinator, async_controller, persister): """Testing launching but don't wait, just get the pid""" loop = asyncio.get_event_loop() launcher = plumpy.ProcessLauncher(loop, persister=persister) - _coordinator.add_task_subscriber(launcher.call) + _coordinator.hook_task_receiver(launcher.call) pid = await async_controller.launch_process(utils.DummyProcess, nowait=True) assert isinstance(pid, uuid.UUID) @@ -165,7 +165,7 @@ async def test_execute_action(self, _coordinator, async_controller, persister): """Test the process execute action""" loop = asyncio.get_event_loop() launcher = plumpy.ProcessLauncher(loop, persister=persister) - _coordinator.add_task_subscriber(launcher.call) + _coordinator.hook_task_receiver(launcher.call) result = await async_controller.execute_process(utils.DummyProcessWithOutput) assert utils.DummyProcessWithOutput.EXPECTED_OUTPUTS == result @@ -174,7 +174,7 @@ async def test_execute_action_nowait(self, _coordinator, async_controller, persi """Test the process execute action""" loop = asyncio.get_event_loop() launcher = plumpy.ProcessLauncher(loop, persister=persister) - _coordinator.add_task_subscriber(launcher.call) + _coordinator.hook_task_receiver(launcher.call) pid = await async_controller.execute_process(utils.DummyProcessWithOutput, nowait=True) assert isinstance(pid, uuid.UUID) @@ -183,7 +183,7 @@ async def test_launch_many(self, _coordinator, async_controller, persister): """Test launching multiple processes""" loop = asyncio.get_event_loop() launcher = plumpy.ProcessLauncher(loop, persister=persister) - _coordinator.add_task_subscriber(launcher.call) + _coordinator.hook_task_receiver(launcher.call) num_to_launch = 10 launch_futures = [] @@ -200,7 +200,7 @@ async def test_continue(self, _coordinator, async_controller, persister): """Test continuing a saved process""" loop = asyncio.get_event_loop() launcher = plumpy.ProcessLauncher(loop, persister=persister) - _coordinator.add_task_subscriber(launcher.call) + _coordinator.hook_task_receiver(launcher.call) process = utils.DummyProcessWithOutput() persister.save_checkpoint(process) pid = process.pid diff --git a/tests/rmq/test_process_control.py b/tests/rmq/test_process_control.py index 79a98ba3..4531f932 100644 --- a/tests/rmq/test_process_control.py +++ b/tests/rmq/test_process_control.py @@ -107,7 +107,7 @@ def test_broadcast(self, _coordinator): def on_broadcast_receive(**msg): messages.append(msg) - _coordinator.add_broadcast_subscriber(on_broadcast_receive) + _coordinator.hook_broadcast_receiver(on_broadcast_receive) proc = utils.DummyProcess(coordinator=_coordinator) proc.execute() diff --git a/tests/test_processes.py b/tests/test_processes.py index 0373b037..1f0160c3 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -1071,7 +1071,7 @@ def test_broadcast(self): def on_broadcast_receive(body, sender, subject, correlation_id): messages.append({'body': body, 'subject': subject, 'sender': sender, 'correlation_id': correlation_id}) - coordinator.add_broadcast_subscriber(on_broadcast_receive) + coordinator.hook_broadcast_receiver(on_broadcast_receive) proc = utils.DummyProcess(coordinator=coordinator) proc.execute() diff --git a/tests/utils.py b/tests/utils.py index 323a3282..c0cf0f52 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,8 +3,9 @@ import asyncio import collections +from re import Pattern import sys -from typing import Any +from typing import TYPE_CHECKING, Any, Callable, Hashable import unittest from collections.abc import Mapping import concurrent.futures @@ -16,14 +17,18 @@ from plumpy.rmq import TaskRejected import shortuuid +if TYPE_CHECKING: + ID_TYPE = Hashable + Receiver = Callable[..., Any] + Snapshot = collections.namedtuple('Snapshot', ['state', 'bundle', 'outputs']) class MockCoordinator: def __init__(self): - self._task_subscribers = {} - self._broadcast_subscribers = {} - self._rpc_subscribers = {} + self._task_receivers = {} + self._broadcast_receivers = {} + self._rpc_receivers = {} self._closed = False def is_closed(self) -> bool: @@ -33,75 +38,87 @@ def close(self): if self._closed: return self._closed = True - del self._task_subscribers - del self._broadcast_subscribers - del self._rpc_subscribers - - def add_rpc_subscriber(self, subscriber, identifier=None) -> Any: + del self._task_receivers + del self._broadcast_receivers + del self._rpc_receivers + + def hook_rpc_receiver( + self, + receiver: 'Receiver', + identifier: 'ID_TYPE | None' = None, + ) -> Any: self._ensure_open() identifier = identifier or shortuuid.uuid() - if identifier in self._rpc_subscribers: - raise RuntimeError(f"Duplicate RPC subscriber with identifier '{identifier}'") - self._rpc_subscribers[identifier] = subscriber + if identifier in self._rpc_receivers: + raise RuntimeError(f"Duplicate RPC receiver with identifier '{identifier}'") + self._rpc_receivers[identifier] = receiver return identifier - def remove_rpc_subscriber(self, identifier): + def unhook_rpc_receiver(self, identifier: 'ID_TYPE | None') -> None: self._ensure_open() try: - self._rpc_subscribers.pop(identifier) + self._rpc_receivers.pop(identifier) except KeyError as exc: - raise ValueError(f"Unknown subscriber '{identifier}'") from exc + raise ValueError(f"Unknown receiver '{identifier}'") from exc - def add_task_subscriber(self, subscriber, identifier=None): - """ - Register a task subscriber + def hook_task_receiver( + self, + receiver: 'Receiver', + identifier: 'ID_TYPE | None' = None, + ) -> 'ID_TYPE': + """Register a task receiver - :param subscriber: The task callback function - :param identifier: the subscriber identifier + :param receiver: The task callback function + :param identifier: the receiver identifier """ self._ensure_open() identifier = identifier or shortuuid.uuid() - if identifier in self._rpc_subscribers: - raise RuntimeError(f"Duplicate RPC subscriber with identifier '{identifier}'") - self._task_subscribers[identifier] = subscriber + if identifier in self._rpc_receivers: + raise RuntimeError(f"Duplicate RPC receiver with identifier '{identifier}'") + self._task_receivers[identifier] = receiver return identifier - def remove_task_subscriber(self, identifier): - """ - Remove a task subscriber + def unhook_task_receiver(self, identifier: 'ID_TYPE') -> None: + """Remove a task receiver - :param identifier: the subscriber to remove - :raises: ValueError if identifier does not correspond to a known subscriber + :param identifier: the receiver to remove + :raises: ValueError if identifier does not correspond to a known receiver """ self._ensure_open() try: - self._task_subscribers.pop(identifier) + self._task_receivers.pop(identifier) except KeyError as exception: - raise ValueError(f"Unknown subscriber: '{identifier}'") from exception - - def add_broadcast_subscriber(self, subscriber, subject_filters=None, sender_filters=None, identifier=None) -> Any: + raise ValueError(f"Unknown receiver: '{identifier}'") from exception + + def hook_broadcast_receiver( + self, + receiver: 'Receiver', + subject_filters: list[Hashable | Pattern[str]] | None = None, + sender_filters: list[Hashable | Pattern[str]] | None = None, + identifier: 'ID_TYPE | None' = None, + ) -> Any: self._ensure_open() identifier = identifier or shortuuid.uuid() - if identifier in self._broadcast_subscribers: - raise RuntimeError(f"Duplicate RPC subscriber with identifier '{identifier}'") + if identifier in self._broadcast_receivers: + raise RuntimeError(f"Duplicate RPC receiver with identifier '{identifier}'") - self._broadcast_subscribers[identifier] = subscriber + self._broadcast_receivers[identifier] = receiver return identifier - def remove_broadcast_subscriber(self, identifier): + def unhook_broadcast_receiver(self, identifier: 'ID_TYPE | None') -> None: self._ensure_open() try: - del self._broadcast_subscribers[identifier] + del self._broadcast_receivers[identifier] except KeyError as exception: - raise ValueError(f"Broadcast subscriber '{identifier}' unknown") from exception + raise ValueError(f"Broadcast receiver '{identifier}' unknown") from exception def task_send(self, msg, no_reply=False): self._ensure_open() future = concurrent.futures.Future() - for subscriber in self._task_subscribers.values(): + for receiver in self._task_receivers.values(): try: - result = subscriber(self, msg) + result = receiver(self, msg) future.set_result(result) break except TaskRejected: @@ -118,13 +135,13 @@ def task_send(self, msg, no_reply=False): def rpc_send(self, recipient_id, msg): self._ensure_open() try: - subscriber = self._rpc_subscribers[recipient_id] + receiver = self._rpc_receivers[recipient_id] except KeyError as exception: raise RuntimeError(f"Unknown rpc recipient '{recipient_id}'") from exception else: future = concurrent.futures.Future() try: - future.set_result(subscriber(self, msg)) + future.set_result(receiver(self, msg)) except Exception: future.set_exception(RuntimeError(sys.exc_info())) @@ -132,8 +149,8 @@ def rpc_send(self, recipient_id, msg): def broadcast_send(self, body, sender=None, subject=None, correlation_id=None): self._ensure_open() - for subscriber in self._broadcast_subscribers.values(): - subscriber(body=body, sender=sender, subject=subject, correlation_id=correlation_id) + for receiver in self._broadcast_receivers.values(): + receiver(body=body, sender=sender, subject=subject, correlation_id=correlation_id) return True def _ensure_open(self): From b1e4e1ebf28627cccdb6edc269d73db55dabc344 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Fri, 21 Feb 2025 23:53:57 +0100 Subject: [PATCH 21/22] ralex - T -> CommT in generic type at communications.py --- src/plumpy/rmq/communications.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/plumpy/rmq/communications.py b/src/plumpy/rmq/communications.py index 50927557..3ed3877a 100644 --- a/src/plumpy/rmq/communications.py +++ b/src/plumpy/rmq/communications.py @@ -79,10 +79,10 @@ def converted(communicator: kiwipy.Communicator, *args: Any, **kwargs: Any) -> k return converted -T = TypeVar('T', bound=kiwipy.Communicator) +CommT = TypeVar('CommT', bound=kiwipy.Communicator) -def wrap_communicator(communicator: T, loop: Optional[asyncio.AbstractEventLoop] = None) -> 'LoopCommunicator[T]': +def wrap_communicator(communicator: CommT, loop: Optional[asyncio.AbstractEventLoop] = None) -> 'LoopCommunicator[CommT]': """ Wrap a communicator such that all callbacks made to any subscribers are scheduled on the given event loop. @@ -103,10 +103,10 @@ def wrap_communicator(communicator: T, loop: Optional[asyncio.AbstractEventLoop] @final -class LoopCommunicator(Generic[T], kiwipy.Communicator): # type: ignore +class LoopCommunicator(Generic[CommT], kiwipy.Communicator): # type: ignore """Wrapper around a `kiwipy.Communicator` that schedules any subscriber messages on a given event loop.""" - def __init__(self, communicator: T, loop: Optional[asyncio.AbstractEventLoop] = None): + def __init__(self, communicator: CommT, loop: Optional[asyncio.AbstractEventLoop] = None): """ :param communicator: The kiwipy communicator :param loop: The event loop to schedule callbacks on @@ -118,7 +118,7 @@ def __init__(self, communicator: T, loop: Optional[asyncio.AbstractEventLoop] = self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() @property - def inner(self) -> T: + def inner(self) -> CommT: return self._communicator def loop(self) -> asyncio.AbstractEventLoop: From 14c6348403bbb5f29b0a2792b9f66f585069a8b0 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Fri, 21 Feb 2025 23:59:33 +0100 Subject: [PATCH 22/22] misc: pre-commit / uv.lock Explicitly declare protocol implementations --- src/plumpy/broadcast_filter.py | 1 + src/plumpy/controller.py | 8 +- src/plumpy/coordinator.py | 9 +- src/plumpy/message.py | 1 - src/plumpy/processes.py | 35 ++--- src/plumpy/rmq/communications.py | 4 +- src/plumpy/rmq/process_control.py | 8 +- tests/rmq/__init__.py | 3 +- tests/rmq/test_coordinator.py | 7 + tests/rmq/test_process_control.py | 5 + tests/test_coordinator.py | 7 + tests/utils.py | 6 +- uv.lock | 209 +++++++++++++----------------- 13 files changed, 148 insertions(+), 155 deletions(-) create mode 100644 tests/rmq/test_coordinator.py create mode 100644 tests/test_coordinator.py diff --git a/src/plumpy/broadcast_filter.py b/src/plumpy/broadcast_filter.py index 61b27095..6ec6c41e 100644 --- a/src/plumpy/broadcast_filter.py +++ b/src/plumpy/broadcast_filter.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +# type: ignore import re import typing diff --git a/src/plumpy/controller.py b/src/plumpy/controller.py index 9f2793a0..34207a4a 100644 --- a/src/plumpy/controller.py +++ b/src/plumpy/controller.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Any, Hashable, Optional, Protocol, Union +from typing import Any, Hashable, Optional, Protocol, Union, runtime_checkable from plumpy import loaders from plumpy.message import MessageType @@ -12,6 +12,7 @@ ProcessStatus = Any +@runtime_checkable class ProcessController(Protocol): """ Control processes using coroutines that will send messages and wait @@ -26,7 +27,7 @@ def get_status(self, pid: 'PID_TYPE') -> ProcessStatus: """ ... - def pause_process(self, pid: 'PID_TYPE', msg: str | None = None) -> ProcessResult: + def pause_process(self, pid: 'PID_TYPE', msg_text: str | None = None) -> Any: """ Pause the process @@ -52,8 +53,7 @@ def play_process(self, pid: 'PID_TYPE') -> ProcessResult: ... def play_all(self) -> None: - """Play all processes that are subscribed to the same coordinator - """ + """Play all processes that are subscribed to the same coordinator""" def kill_process(self, pid: 'PID_TYPE', msg_text: str | None = None) -> Any: """Kill the process diff --git a/src/plumpy/coordinator.py b/src/plumpy/coordinator.py index e647961e..ab97cb20 100644 --- a/src/plumpy/coordinator.py +++ b/src/plumpy/coordinator.py @@ -1,14 +1,15 @@ # -*- coding: utf-8 -*- from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Hashable, Protocol from re import Pattern +from typing import TYPE_CHECKING, Any, Callable, Hashable, Protocol, runtime_checkable if TYPE_CHECKING: ID_TYPE = Hashable Receiver = Callable[..., Any] +@runtime_checkable class Coordinator(Protocol): def hook_rpc_receiver( self, @@ -36,7 +37,11 @@ def unhook_broadcast_receiver(self, identifier: 'ID_TYPE | None') -> None: ... def unhook_task_receiver(self, identifier: 'ID_TYPE') -> None: ... - def rpc_send(self, recipient_id: Hashable, msg: Any,) -> Any: ... + def rpc_send( + self, + recipient_id: Hashable, + msg: Any, + ) -> Any: ... def broadcast_send( self, diff --git a/src/plumpy/message.py b/src/plumpy/message.py index 098277e1..99d215ff 100644 --- a/src/plumpy/message.py +++ b/src/plumpy/message.py @@ -7,7 +7,6 @@ import logging from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast -from plumpy.coordinator import Coordinator from plumpy.exceptions import PersistenceError, TaskRejectedError from . import loaders, persistence diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 75737574..1852af3d 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -34,9 +34,7 @@ cast, ) -import kiwipy - -from plumpy.broadcast_filter import BroadcastFilter +from plumpy.broadcast_filter import BroadcastFilter # type: ignore from plumpy.coordinator import Coordinator try: @@ -944,19 +942,16 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non # region Communication def message_receive(self, msg: MessageType) -> Any: - """ - Coroutine called when the process receives a message from the communicator + """Coroutine called when the process receives a message from the communicator - :param _comm: the communicator that sent the message :param msg: the message :return: the outcome of processing the message, the return value will be sent back as a response to the sender """ - # self.logger.debug( - # "Process<%s>: received RPC message with communicator '%s': %r", - # self.pid, - # _comm, - # msg, - # ) + self.logger.debug( + 'Process<%s>: received RPC message: %r', + self.pid, + msg, + ) intent = msg[message.INTENT_KEY] @@ -977,19 +972,17 @@ def message_receive(self, msg: MessageType) -> Any: def broadcast_receive( self, msg: MessageType, sender: Any, subject: Any, correlation_id: Any ) -> Optional[concurrent.futures.Future]: - """ - Coroutine called when the process receives a message from the communicator + """Coroutine called when the process receives a message from the communicator :param msg: the message """ + self.logger.debug( + "Process<%s>: received broadcast message '%s': %r", + self.pid, + subject, + msg, + ) - # self.logger.debug( - # "Process<%s>: received broadcast message '%s' with communicator '%s': %r", - # self.pid, - # subject, - # _comm, - # msg, - # ) # If we get a message we recognise then action it, otherwise ignore fn = None if subject == message.Intent.PLAY: diff --git a/src/plumpy/rmq/communications.py b/src/plumpy/rmq/communications.py index 3ed3877a..e9f227b0 100644 --- a/src/plumpy/rmq/communications.py +++ b/src/plumpy/rmq/communications.py @@ -82,7 +82,9 @@ def converted(communicator: kiwipy.Communicator, *args: Any, **kwargs: Any) -> k CommT = TypeVar('CommT', bound=kiwipy.Communicator) -def wrap_communicator(communicator: CommT, loop: Optional[asyncio.AbstractEventLoop] = None) -> 'LoopCommunicator[CommT]': +def wrap_communicator( + communicator: CommT, loop: Optional[asyncio.AbstractEventLoop] = None +) -> 'LoopCommunicator[CommT]': """ Wrap a communicator such that all callbacks made to any subscribers are scheduled on the given event loop. diff --git a/src/plumpy/rmq/process_control.py b/src/plumpy/rmq/process_control.py index 02eb8853..9a3415b5 100644 --- a/src/plumpy/rmq/process_control.py +++ b/src/plumpy/rmq/process_control.py @@ -9,6 +9,7 @@ import kiwipy from plumpy import loaders +from plumpy.controller import ProcessController from plumpy.coordinator import Coordinator from plumpy.message import ( Intent, @@ -29,8 +30,7 @@ ProcessStatus = Any -# FIXME: the class not fit typing of ProcessController protocol -class RemoteProcessController: +class RemoteProcessController(ProcessController): """ Control remote processes using coroutines that will send messages and wait (in a non-blocking way) for their response @@ -190,7 +190,7 @@ async def execute_process( return result -class RemoteProcessThreadController: +class RemoteProcessThreadController(ProcessController): """ A class that can be used to control and launch remote processes """ @@ -212,7 +212,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future: """ return self._coordinator.rpc_send(pid, MessageBuilder.status()) - def pause_process(self, pid: 'PID_TYPE', msg_text: str | None = None) -> kiwipy.Future: + def pause_process(self, pid: 'PID_TYPE', msg_text: str | None = None) -> Any: """Pause the process :param pid: the pid of the process to pause diff --git a/tests/rmq/__init__.py b/tests/rmq/__init__.py index 91af1549..6e59a7a8 100644 --- a/tests/rmq/__init__.py +++ b/tests/rmq/__init__.py @@ -5,6 +5,7 @@ import kiwipy import concurrent.futures +from plumpy.coordinator import Coordinator from plumpy.exceptions import CoordinatorConnectionError if TYPE_CHECKING: @@ -15,7 +16,7 @@ @final -class RmqCoordinator(Generic[U]): +class RmqCoordinator(Coordinator, Generic[U]): def __init__(self, comm: U): self._comm = comm diff --git a/tests/rmq/test_coordinator.py b/tests/rmq/test_coordinator.py new file mode 100644 index 00000000..ce242bda --- /dev/null +++ b/tests/rmq/test_coordinator.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- +from plumpy.coordinator import Coordinator +from . import RmqCoordinator + + +def test_mock_coordinator(): + assert isinstance(RmqCoordinator, Coordinator) diff --git a/tests/rmq/test_process_control.py b/tests/rmq/test_process_control.py index 4531f932..3c69272e 100644 --- a/tests/rmq/test_process_control.py +++ b/tests/rmq/test_process_control.py @@ -7,6 +7,7 @@ from kiwipy import rmq import plumpy +from plumpy.controller import ProcessController from plumpy.rmq import process_control from . import RmqCoordinator @@ -42,6 +43,10 @@ def async_controller(_coordinator): def sync_controller(_coordinator): yield process_control.RemoteProcessThreadController(_coordinator) +def test_remote_process_controller(sync_controller, async_controller): + assert isinstance(sync_controller, ProcessController) + assert isinstance(async_controller, ProcessController) + class TestRemoteProcessController: @pytest.mark.asyncio diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py new file mode 100644 index 00000000..88d295f8 --- /dev/null +++ b/tests/test_coordinator.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- +from plumpy.coordinator import Coordinator +from .utils import MockCoordinator + + +def test_mock_coordinator(): + assert isinstance(MockCoordinator, Coordinator) diff --git a/tests/utils.py b/tests/utils.py index c0cf0f52..bc969ca4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Utilities for tests""" +from __future__ import annotations import asyncio import collections @@ -12,6 +13,7 @@ import plumpy from plumpy import persistence, process_states, processes, utils +from plumpy.coordinator import Coordinator from plumpy.exceptions import CoordinatorConnectionError from plumpy.message import MessageBuilder from plumpy.rmq import TaskRejected @@ -24,7 +26,7 @@ Snapshot = collections.namedtuple('Snapshot', ['state', 'bundle', 'outputs']) -class MockCoordinator: +class MockCoordinator(Coordinator): def __init__(self): self._task_receivers = {} self._broadcast_receivers = {} @@ -105,7 +107,7 @@ def hook_broadcast_receiver( self._broadcast_receivers[identifier] = receiver return identifier - def unhook_broadcast_receiver(self, identifier: 'ID_TYPE | None') -> None: + def unhook_broadcast_receiver(self, identifier: 'ID_TYPE | None') -> None: self._ensure_open() try: del self._broadcast_receivers[identifier] diff --git a/uv.lock b/uv.lock index a586981f..ef84ff2a 100644 --- a/uv.lock +++ b/uv.lock @@ -6,6 +6,18 @@ resolution-markers = [ "python_full_version < '3.10'", ] +[[package]] +name = "accessible-pygments" +version = "0.0.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bc/c1/bbac6a50d02774f91572938964c582fff4270eee73ab822a4aeea4d8b11b/accessible_pygments-0.0.5.tar.gz", hash = "sha256:40918d3e6a2b619ad424cb91e556bd3bd8865443d9f22f1dcdf79e33c8046872", size = 1377899 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8d/3f/95338030883d8c8b91223b4e21744b04d11b161a3ef117295d8241f50ab4/accessible_pygments-0.0.5-py3-none-any.whl", hash = "sha256:88ae3211e68a1d0b011504b2ffc1691feafce124b845bd072ab6f9f66f34d4b7", size = 1395903 }, +] + [[package]] name = "aio-pika" version = "9.4.3" @@ -489,26 +501,11 @@ wheels = [ [[package]] name = "docutils" -version = "0.17.1" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.10'", -] -sdist = { url = "https://files.pythonhosted.org/packages/4c/17/559b4d020f4b46e0287a2eddf2d8ebf76318fd3bd495f1625414b052fdc9/docutils-0.17.1.tar.gz", hash = "sha256:686577d2e4c32380bb50cbb22f575ed742d58168cee37e99117a854bcd88f125", size = 2016138 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4c/5e/6003a0d1f37725ec2ebd4046b657abb9372202655f96e76795dca8c0063c/docutils-0.17.1-py2.py3-none-any.whl", hash = "sha256:cf316c8370a737a022b72b56874f6602acf974a37a9fba42ec2876387549fc61", size = 575533 }, -] - -[[package]] -name = "docutils" -version = "0.21.2" +version = "0.20.1" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.10'", -] -sdist = { url = "https://files.pythonhosted.org/packages/ae/ed/aefcc8cd0ba62a0560c3c18c33925362d46c6075480bfa4df87b28e169a9/docutils-0.21.2.tar.gz", hash = "sha256:3a6b18732edf182daa3cd12775bbb338cf5691468f91eeeb109deff6ebfa986f", size = 2204444 } +sdist = { url = "https://files.pythonhosted.org/packages/1f/53/a5da4f2c5739cf66290fac1431ee52aff6851c7c8ffd8264f13affd7bcdd/docutils-0.20.1.tar.gz", hash = "sha256:f08a4e276c3a1583a86dce3e34aba3fe04d02bba2dd51ed16106244e8a923e3b", size = 2058365 } wheels = [ - { url = "https://files.pythonhosted.org/packages/8f/d7/9322c609343d929e75e7e5e6255e614fcc67572cfd083959cdef3b7aad79/docutils-0.21.2-py3-none-any.whl", hash = "sha256:dafca5b9e384f0e419294eb4d2ff9fa826435bf15f15b7bd45723e8ad76811b2", size = 587408 }, + { url = "https://files.pythonhosted.org/packages/26/87/f238c0670b94533ac0353a4e2a1a771a0cc73277b88bff23d3ae35a256c1/docutils-0.20.1-py3-none-any.whl", hash = "sha256:96f387a2c5562db4476f09f13bbab2192e764cac08ebbf3a34a95d9b1e4a59d6", size = 572666 }, ] [[package]] @@ -785,14 +782,14 @@ wheels = [ [[package]] name = "jinja2" -version = "2.11.3" +version = "3.1.5" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "markupsafe" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4f/e7/65300e6b32e69768ded990494809106f87da1d436418d5f1367ed3966fd7/Jinja2-2.11.3.tar.gz", hash = "sha256:a6d58433de0ae800347cab1fa3043cebbabe8baa9d29e668f1c768cb87a333c6", size = 257589 } +sdist = { url = "https://files.pythonhosted.org/packages/af/92/b3130cbbf5591acf9ade8708c365f3238046ac7cb8ccba6e81abccb0ccff/jinja2-3.1.5.tar.gz", hash = "sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb", size = 244674 } wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/c2/1eece8c95ddbc9b1aeb64f5783a9e07a286de42191b7204d67b7496ddf35/Jinja2-2.11.3-py2.py3-none-any.whl", hash = "sha256:03e47ad063331dd6a3f04a43eddca8a966a26ba0c5b7207a9a9e4e08f1b29419", size = 125699 }, + { url = "https://files.pythonhosted.org/packages/bd/0f/2ba5fbcd631e3e88689309dbe978c5769e883e4b84ebfe7da30b43275c5a/jinja2-3.1.5-py3-none-any.whl", hash = "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb", size = 134596 }, ] [[package]] @@ -836,18 +833,21 @@ wheels = [ [[package]] name = "jupyter-cache" -version = "0.4.3" +version = "1.0.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "attrs" }, + { name = "click" }, + { name = "importlib-metadata" }, { name = "nbclient" }, - { name = "nbdime" }, { name = "nbformat" }, + { name = "pyyaml" }, { name = "sqlalchemy" }, + { name = "tabulate" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/49/cd/43a393cd0e5a5019598bf899c3ccfac4b8ac92b6b47d25980a44cc1a3ec3/jupyter-cache-0.4.3.tar.gz", hash = "sha256:4c9b5431b1d320bc68440c21fa0a155bbeb29c5b979bef72222e244a7bcd54fc", size = 29068 } +sdist = { url = "https://files.pythonhosted.org/packages/bb/f7/3627358075f183956e8c4974603232b03afd4ddc7baf72c2bc9fff522291/jupyter_cache-1.0.1.tar.gz", hash = "sha256:16e808eb19e3fb67a223db906e131ea6e01f03aa27f49a7214ce6a5fec186fb9", size = 32048 } wheels = [ - { url = "https://files.pythonhosted.org/packages/01/b1/5be5e126e5afb004a487443b21f5f39642f471323ca80ac17b1edd62696a/jupyter_cache-0.4.3-py3-none-any.whl", hash = "sha256:6d5d662d81f565d18009e8dcfd3a56fb876af47eafead2a19ef0045aba8ffe3b", size = 31668 }, + { url = "https://files.pythonhosted.org/packages/64/6b/67b87da9d36bff9df7d0efbd1a325fa372a43be7158effaf43ed7b22341d/jupyter_cache-1.0.1-py3-none-any.whl", hash = "sha256:9c3cafd825ba7da8b5830485343091143dff903e4d8c69db9349b728b140abf6", size = 33907 }, ] [[package]] @@ -936,22 +936,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7d/77/6a98cc88f1061c0206b427b602efb6fcb9bc369e958aee11676d5cfc4412/jupyter_server_mathjax-0.2.6-py3-none-any.whl", hash = "sha256:416389dde2010df46d5fbbb7adb087a5607111070af65a1445391040f2babb5e", size = 3120990 }, ] -[[package]] -name = "jupyter-sphinx" -version = "0.3.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "ipython" }, - { name = "ipywidgets" }, - { name = "nbconvert" }, - { name = "nbformat" }, - { name = "sphinx" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/5c/b0/cc381afa960b7af1b4abac58abbedc0fd93d8805d422acd5d2b26682744f/jupyter_sphinx-0.3.1.tar.gz", hash = "sha256:c4caf8bbf2be6edfe0319aa76127d17fdbe6927c8189cda2d6ac59c01f38404b", size = 16686 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/5d/0f/3fedf88d1e5ac7b74e26a0f99f3e1c242e45484c5c0a7487b51e151d09f2/jupyter_sphinx-0.3.1-py3-none-any.whl", hash = "sha256:56f4cd319b96c491c61bfa9d11a2ee452d2758beecbd2723b23916aaac4c2bab", size = 19781 }, -] - [[package]] name = "jupyterlab" version = "3.3.4" @@ -1015,8 +999,7 @@ wheels = [ [package.optional-dependencies] docs = [ - { name = "docutils", version = "0.17.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, - { name = "docutils", version = "0.21.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "docutils" }, { name = "jupyter" }, { name = "nbsphinx", version = "0.9.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "nbsphinx", version = "0.9.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, @@ -1032,15 +1015,14 @@ rmq = [ [[package]] name = "markdown-it-py" -version = "0.6.2" +version = "3.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "attrs" }, - { name = "mdit-py-plugins" }, + { name = "mdurl" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0e/c0/8b6e358df933d68c7cc7202ed454eab6a411b792796646e4ced4a998a47d/markdown-it-py-0.6.2.tar.gz", hash = "sha256:c3b9f995be0792cbbc8ab2f53d74072eb7ff8a8b622be8d61d38ab879709eca3", size = 55904 } +sdist = { url = "https://files.pythonhosted.org/packages/38/71/3b932df36c1a044d397a1f92d1cf91ee0a503d91e470cbd670aa66b07ed0/markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb", size = 74596 } wheels = [ - { url = "https://files.pythonhosted.org/packages/2e/cb/8493188845d26599170268bb0e0a63e75584d5e7f130488c641e96449cd7/markdown_it_py-0.6.2-py3-none-any.whl", hash = "sha256:30b3e9f8198dc82a5df0dcb73fd31d56cd9a43bf8a747feb10b2ba74f962bcb1", size = 81687 }, + { url = "https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1", size = 87528 }, ] [[package]] @@ -1090,14 +1072,23 @@ wheels = [ [[package]] name = "mdit-py-plugins" -version = "0.2.6" +version = "0.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "markdown-it-py" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e1/28/1ba872560eef4bd28873c53e1ecd63fc70ca971054055e18c1c891576901/mdit-py-plugins-0.2.6.tar.gz", hash = "sha256:1e467ca2ea056e8065cbd5d6c61e5052bb50826bde84c40f6a5ed77e82125710", size = 25661 } +sdist = { url = "https://files.pythonhosted.org/packages/19/03/a2ecab526543b152300717cf232bb4bb8605b6edb946c845016fa9c9c9fd/mdit_py_plugins-0.4.2.tar.gz", hash = "sha256:5f2cd1fdb606ddf152d37ec30e46101a60512bc0e5fa1a7002c36647b09e26b5", size = 43542 } wheels = [ - { url = "https://files.pythonhosted.org/packages/0c/31/f0ecaccf7cd2db17332a94852f190840167c3cb7eadf09efe498412f909a/mdit_py_plugins-0.2.6-py3-none-any.whl", hash = "sha256:77fd75dad81109ee91f30eb49146196f79afbbae041f298ae4886c8c2b5e23d7", size = 39287 }, + { url = "https://files.pythonhosted.org/packages/a7/f7/7782a043553ee469c1ff49cfa1cdace2d6bf99a1f333cf38676b3ddf30da/mdit_py_plugins-0.4.2-py3-none-any.whl", hash = "sha256:0c673c3f889399a33b95e88d2f0d111b4447bdfea7f237dab2d488f459835636", size = 55316 }, +] + +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 }, ] [[package]] @@ -1246,67 +1237,64 @@ wheels = [ [[package]] name = "myst-nb" -version = "0.11.1" +version = "1.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "docutils", version = "0.17.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, - { name = "docutils", version = "0.21.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "importlib-metadata" }, + { name = "ipykernel" }, { name = "ipython" }, - { name = "ipywidgets" }, { name = "jupyter-cache" }, - { name = "jupyter-sphinx" }, - { name = "myst-parser", version = "0.13.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, - { name = "myst-parser", version = "0.13.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, - { name = "nbconvert" }, + { name = "myst-parser", version = "3.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "myst-parser", version = "4.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "nbclient" }, { name = "nbformat" }, { name = "pyyaml" }, { name = "sphinx" }, - { name = "sphinx-togglebutton" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fa/9f/06bc463c1bcbfb7b48b56240d4fd7a8cefccfc510c7cfde77b8c07bb7fe4/myst-nb-0.11.1.tar.gz", hash = "sha256:1ac530645296310c61ccb7e767309c6498fa386ccc41499f5ec1f6b57a4dd1c9", size = 31598 } +sdist = { url = "https://files.pythonhosted.org/packages/ae/8f/71d983ed85b1aff17db25e447a9beb67b50a9116c7cff5cde26796d1ffd0/myst_nb-1.2.0.tar.gz", hash = "sha256:af459ec753b341952182b45b0a80b4776cebf80c9ee6aaca2a3f4027b440c9de", size = 79446 } wheels = [ - { url = "https://files.pythonhosted.org/packages/a5/b0/350f2e4b9f21a58d87e93457bdcda89a62b1657e7ddf93c11c059caf6cbe/myst_nb-0.11.1-py3-none-any.whl", hash = "sha256:f009fc7552b425be2250476c92a0e07a5c6f12a27755f265fc2bc5be511a47a6", size = 36654 }, + { url = "https://files.pythonhosted.org/packages/40/98/fa9dee0caf4e1f2e895d047952bf84a64eb95102df14c82c20594c0afa5f/myst_nb-1.2.0-py3-none-any.whl", hash = "sha256:0e09909877848c0cf45e1aecee97481512efa29a0c4caa37870a03bba11c56c1", size = 80303 }, ] [[package]] name = "myst-parser" -version = "0.13.6" +version = "3.0.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version < '3.10'", ] dependencies = [ - { name = "docutils", version = "0.21.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "docutils", marker = "python_full_version < '3.10'" }, { name = "jinja2", marker = "python_full_version < '3.10'" }, { name = "markdown-it-py", marker = "python_full_version < '3.10'" }, { name = "mdit-py-plugins", marker = "python_full_version < '3.10'" }, { name = "pyyaml", marker = "python_full_version < '3.10'" }, { name = "sphinx", marker = "python_full_version < '3.10'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/be/c7/644c475014b7e0c1ac625a9412a0a3f9b1dbb354d43ed12000d3ac8073f8/myst-parser-0.13.6.tar.gz", hash = "sha256:bec01ecebe9b9c04322f8aebd6fd8e61d2cb9ab711d531065a374cc3dcb1d7be", size = 43824 } +sdist = { url = "https://files.pythonhosted.org/packages/49/64/e2f13dac02f599980798c01156393b781aec983b52a6e4057ee58f07c43a/myst_parser-3.0.1.tar.gz", hash = "sha256:88f0cb406cb363b077d176b51c476f62d60604d68a8dcdf4832e080441301a87", size = 92392 } wheels = [ - { url = "https://files.pythonhosted.org/packages/a6/dc/0a77028b5b7bf8661e1c73569b72f2b822e4d7a570a34d29d01ac789b626/myst_parser-0.13.6-py3-none-any.whl", hash = "sha256:a448b3dcb39bc62a6954f5e18544b83d69ed69d8947cf01f8ebe8b654921b4bf", size = 43785 }, + { url = "https://files.pythonhosted.org/packages/e2/de/21aa8394f16add8f7427f0a1326ccd2b3a2a8a3245c9252bc5ac034c6155/myst_parser-3.0.1-py3-none-any.whl", hash = "sha256:6457aaa33a5d474aca678b8ead9b3dc298e89c68e67012e73146ea6fd54babf1", size = 83163 }, ] [[package]] name = "myst-parser" -version = "0.13.7" +version = "4.0.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.10'", ] dependencies = [ - { name = "docutils", version = "0.17.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "docutils", marker = "python_full_version >= '3.10'" }, { name = "jinja2", marker = "python_full_version >= '3.10'" }, { name = "markdown-it-py", marker = "python_full_version >= '3.10'" }, { name = "mdit-py-plugins", marker = "python_full_version >= '3.10'" }, { name = "pyyaml", marker = "python_full_version >= '3.10'" }, { name = "sphinx", marker = "python_full_version >= '3.10'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a6/af/e7c4c8634bf90664efa6ab3d550dc6c526b59b2990e9a4bdd192f8edc4aa/myst-parser-0.13.7.tar.gz", hash = "sha256:e4bc99e43e19f70d22e528de8e7cce59f7e8e7c4c34dcba203de92de7a7c7c85", size = 45618 } +sdist = { url = "https://files.pythonhosted.org/packages/66/a5/9626ba4f73555b3735ad86247a8077d4603aa8628537687c839ab08bfe44/myst_parser-4.0.1.tar.gz", hash = "sha256:5cfea715e4f3574138aecbf7d54132296bfd72bb614d31168f48c477a830a7c4", size = 93985 } wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/40/db9563e8b57710ea9742b74e5228a4bcb8130aceeeab71f8315ca79a7b57/myst_parser-0.13.7-py3-none-any.whl", hash = "sha256:260355b4da8e8865fe080b0638d7f1ab1791dc4bed02a7a48630b6bad4249219", size = 44007 }, + { url = "https://files.pythonhosted.org/packages/5f/df/76d0321c3797b54b60fef9ec3bd6f4cfd124b9e422182156a1dd418722cf/myst_parser-4.0.1-py3-none-any.whl", hash = "sha256:9134e88959ec3b5780aedf8a99680ea242869d012e8821db3126d427edc9c95d", size = 84579 }, ] [[package]] @@ -1417,7 +1405,7 @@ resolution-markers = [ "python_full_version >= '3.10'", ] dependencies = [ - { name = "docutils", version = "0.17.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "docutils", marker = "python_full_version >= '3.10'" }, { name = "jinja2", marker = "python_full_version >= '3.10'" }, { name = "nbconvert", marker = "python_full_version >= '3.10'" }, { name = "nbformat", marker = "python_full_version >= '3.10'" }, @@ -1437,7 +1425,7 @@ resolution-markers = [ "python_full_version < '3.10'", ] dependencies = [ - { name = "docutils", version = "0.21.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "docutils", marker = "python_full_version < '3.10'" }, { name = "jinja2", marker = "python_full_version < '3.10'" }, { name = "nbconvert", marker = "python_full_version < '3.10'" }, { name = "nbformat", marker = "python_full_version < '3.10'" }, @@ -1616,7 +1604,6 @@ dependencies = [ docs = [ { name = "importlib-metadata" }, { name = "ipython" }, - { name = "jinja2" }, { name = "kiwipy", extra = ["docs"] }, { name = "markupsafe" }, { name = "myst-nb" }, @@ -1644,12 +1631,11 @@ requires-dist = [ { name = "importlib-resources", marker = "extra == 'tests'", specifier = "~=5.2" }, { name = "ipykernel", marker = "extra == 'tests'", specifier = "==6.12.1" }, { name = "ipython", marker = "extra == 'docs'", specifier = "~=7.0" }, - { name = "jinja2", marker = "extra == 'docs'", specifier = "==2.11.3" }, { name = "kiwipy", extras = ["docs"], marker = "extra == 'docs'", specifier = "~=0.8.3" }, { name = "kiwipy", extras = ["rmq"], specifier = "~=0.8.5" }, { name = "markupsafe", marker = "extra == 'docs'", specifier = "==2.0.1" }, { name = "mypy", marker = "extra == 'pre-commit'", specifier = "==1.13.0" }, - { name = "myst-nb", marker = "extra == 'docs'", specifier = "~=0.11.0" }, + { name = "myst-nb", marker = "extra == 'docs'", specifier = "~=1.2.0" }, { name = "nest-asyncio", specifier = "~=1.5,>=1.5.1" }, { name = "pre-commit", marker = "extra == 'pre-commit'", specifier = "~=2.2" }, { name = "pytest", marker = "extra == 'tests'", specifier = "~=7.0" }, @@ -1658,8 +1644,8 @@ requires-dist = [ { name = "pytest-notebook", marker = "extra == 'tests'", specifier = ">=0.8.0" }, { name = "pyyaml", specifier = "~=6.0" }, { name = "shortuuid", marker = "extra == 'tests'", specifier = "==1.0.8" }, - { name = "sphinx", marker = "extra == 'docs'", specifier = "~=3.2.0" }, - { name = "sphinx-book-theme", marker = "extra == 'docs'", specifier = "~=0.0.39" }, + { name = "sphinx", marker = "extra == 'docs'", specifier = "~=7.2.0" }, + { name = "sphinx-book-theme", marker = "extra == 'docs'", specifier = "~=1.1.4" }, { name = "types-pyyaml", marker = "extra == 'pre-commit'" }, ] provides-extras = ["docs", "pre-commit", "tests"] @@ -1850,14 +1836,21 @@ wheels = [ [[package]] name = "pydata-sphinx-theme" -version = "0.4.3" +version = "0.15.4" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "accessible-pygments" }, + { name = "babel" }, + { name = "beautifulsoup4" }, + { name = "docutils" }, + { name = "packaging" }, + { name = "pygments" }, { name = "sphinx" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e6/4a/01439756d28d0d1b4af1fa347efeff73f6f4e64c8b5132325cc3c0862d03/pydata-sphinx-theme-0.4.3.tar.gz", hash = "sha256:8cf8fbc74c6c47d6ed497a91f3bedf94d57383b52eebb4fa05ae7fc4f50767a2", size = 2141791 } +sdist = { url = "https://files.pythonhosted.org/packages/67/ea/3ab478cccacc2e8ef69892c42c44ae547bae089f356c4b47caf61730958d/pydata_sphinx_theme-0.15.4.tar.gz", hash = "sha256:7762ec0ac59df3acecf49fd2f889e1b4565dbce8b88b2e29ee06fdd90645a06d", size = 2400673 } wheels = [ - { url = "https://files.pythonhosted.org/packages/f8/7f/b11e6bd6d1a8419b29b54b0f2594f879cf766b834acce8df2bcd9fed301b/pydata_sphinx_theme-0.4.3-py3-none-any.whl", hash = "sha256:aa0ae055de5de36a637387941d0e18d8dad35d97d56f0faf25f5219658d49df2", size = 2144730 }, + { url = "https://files.pythonhosted.org/packages/e7/d3/c622950d87a2ffd1654208733b5bd1c5645930014abed8f4c0d74863988b/pydata_sphinx_theme-0.15.4-py3-none-any.whl", hash = "sha256:2136ad0e9500d0949f96167e63f3e298620040aea8f9c74621959eda5d4cf8e6", size = 4640157 }, ] [[package]] @@ -2252,20 +2245,19 @@ wheels = [ [[package]] name = "sphinx" -version = "3.2.1" +version = "7.2.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alabaster" }, { name = "babel" }, { name = "colorama", marker = "sys_platform == 'win32'" }, - { name = "docutils", version = "0.17.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, - { name = "docutils", version = "0.21.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "docutils" }, { name = "imagesize" }, + { name = "importlib-metadata", marker = "python_full_version < '3.10'" }, { name = "jinja2" }, { name = "packaging" }, { name = "pygments" }, { name = "requests" }, - { name = "setuptools" }, { name = "snowballstemmer" }, { name = "sphinxcontrib-applehelp" }, { name = "sphinxcontrib-devhelp" }, @@ -2274,9 +2266,9 @@ dependencies = [ { name = "sphinxcontrib-qthelp" }, { name = "sphinxcontrib-serializinghtml" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/57/58/48268b16bf3e6e8288c4c6f3d500e4dd1ca0210289a5be8366bd6d2e6088/Sphinx-3.2.1.tar.gz", hash = "sha256:321d6d9b16fa381a5306e5a0b76cd48ffbc588e6340059a729c6fdd66087e0e8", size = 5970067 } +sdist = { url = "https://files.pythonhosted.org/packages/73/8e/6e51da4b26665b4b92b1944ea18b2d9c825e753e19180cc5bdc818d0ed3b/sphinx-7.2.6.tar.gz", hash = "sha256:9a5160e1ea90688d5963ba09a2dcd8bdd526620edbb65c328728f1b2228d5ab5", size = 7015183 } wheels = [ - { url = "https://files.pythonhosted.org/packages/63/b8/34ba32a94cb2b223b941e43b3bcab11281763b95daa8587879eec1eb9a62/Sphinx-3.2.1-py3-none-any.whl", hash = "sha256:ce6fd7ff5b215af39e2fcd44d4a321f6694b4530b6f2b2109b64d120773faea0", size = 2868133 }, + { url = "https://files.pythonhosted.org/packages/b2/b6/8ed35256aa530a9d3da15d20bdc0ba888d5364441bb50a5a83ee7827affe/sphinx-7.2.6-py3-none-any.whl", hash = "sha256:1e09160a40b956dc623c910118fa636da93bd3ca0b9876a7b3df90f07d691560", size = 3207959 }, ] [[package]] @@ -2298,36 +2290,15 @@ wheels = [ [[package]] name = "sphinx-book-theme" -version = "0.0.42" +version = "1.1.4" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "beautifulsoup4" }, - { name = "click" }, - { name = "docutils", version = "0.17.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, - { name = "docutils", version = "0.21.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "pydata-sphinx-theme" }, - { name = "pyyaml" }, { name = "sphinx" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a7/ac/704a3e5bbf0aab437851ffd6a5ff366e8f5bf797d44d7eb6dd05e9d39975/sphinx-book-theme-0.0.42.tar.gz", hash = "sha256:a67d3ead308eedec048d52d0ef0f958795f432464b9db02d6612a5697bcf9e33", size = 56585 } +sdist = { url = "https://files.pythonhosted.org/packages/45/19/d002ed96bdc7738c15847c730e1e88282d738263deac705d5713b4d8fa94/sphinx_book_theme-1.1.4.tar.gz", hash = "sha256:73efe28af871d0a89bd05856d300e61edce0d5b2fbb7984e84454be0fedfe9ed", size = 439188 } wheels = [ - { url = "https://files.pythonhosted.org/packages/8b/da/426f72e1c45f0e1394c21a872d2a610370f3950293b35d2ed0d773284b7f/sphinx_book_theme-0.0.42-py3-none-any.whl", hash = "sha256:ce958d2c6d91573215f0f591bf97c68f722be313e3c0a19983ab571cc642aed3", size = 89459 }, -] - -[[package]] -name = "sphinx-togglebutton" -version = "0.2.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "docutils", version = "0.17.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, - { name = "docutils", version = "0.21.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, - { name = "setuptools" }, - { name = "sphinx" }, - { name = "wheel" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/25/e7/cfe952ad8de462080eaebb41108994d5c822b4911fbb65ecb1ec79d25446/sphinx-togglebutton-0.2.3.tar.gz", hash = "sha256:41cbe2f87459eade8dc5718bb56146e8e113a05fb97459b90472470f0d357b55", size = 5411 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c9/42/deaa3b89f7617cee51df70c1bcecaf885ab3d59302a2b96ce89d85da118a/sphinx_togglebutton-0.2.3-py3-none-any.whl", hash = "sha256:8a3707154b1b3480a7918f189f43b7eee0d34ffa552895af77bb273476b8d5e0", size = 6144 }, + { url = "https://files.pythonhosted.org/packages/51/9e/c41d68be04eef5b6202b468e0f90faf0c469f3a03353f2a218fd78279710/sphinx_book_theme-1.1.4-py3-none-any.whl", hash = "sha256:843b3f5c8684640f4a2d01abd298beb66452d1b2394cd9ef5be5ebd5640ea0e1", size = 433952 }, ] [[package]] @@ -2432,6 +2403,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/61/f2b52e107b1fc8944b33ef56bf6ac4ebbe16d91b94d2b87ce013bf63fb84/starlette-0.45.3-py3-none-any.whl", hash = "sha256:dfb6d332576f136ec740296c7e8bb8c8a7125044e7c6da30744718880cdd059d", size = 71507 }, ] +[[package]] +name = "tabulate" +version = "0.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/fe/802052aecb21e3797b8f7902564ab6ea0d60ff8ca23952079064155d1ae1/tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c", size = 81090 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252 }, +] + [[package]] name = "terminado" version = "0.18.1" @@ -2760,15 +2740,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e8/b2/31eec524b53f01cd8343f10a8e429730c52c1849941d1f530f8253b6d934/websockets-15.0-py3-none-any.whl", hash = "sha256:51ffd53c53c4442415b613497a34ba0aa7b99ac07f1e4a62db5dcd640ae6c3c3", size = 169023 }, ] -[[package]] -name = "wheel" -version = "0.45.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/8a/98/2d9906746cdc6a6ef809ae6338005b3f21bb568bea3165cfc6a243fdc25c/wheel-0.45.1.tar.gz", hash = "sha256:661e1abd9198507b1409a20c02106d9670b2576e916d58f520316666abca6729", size = 107545 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/0b/2c/87f3254fd8ffd29e4c02732eee68a83a1d3c346ae39bc6822dcbcb697f2b/wheel-0.45.1-py3-none-any.whl", hash = "sha256:708e7481cc80179af0e556bbf0cc00b8444c7321e2700b8d8580231d13017248", size = 72494 }, -] - [[package]] name = "widgetsnbextension" version = "3.6.10"