Skip to content

Commit

Permalink
misc: pre-commit / uv.lock
Browse files Browse the repository at this point in the history
Explicitly declare protocol implementations
  • Loading branch information
unkcpz committed Feb 27, 2025
1 parent b1e4e1e commit 14c6348
Show file tree
Hide file tree
Showing 13 changed files with 148 additions and 155 deletions.
1 change: 1 addition & 0 deletions src/plumpy/broadcast_filter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
# type: ignore
import re
import typing

Expand Down
8 changes: 4 additions & 4 deletions src/plumpy/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Any, Hashable, Optional, Protocol, Union
from typing import Any, Hashable, Optional, Protocol, Union, runtime_checkable

from plumpy import loaders
from plumpy.message import MessageType
Expand All @@ -12,6 +12,7 @@
ProcessStatus = Any


@runtime_checkable
class ProcessController(Protocol):
"""
Control processes using coroutines that will send messages and wait
Expand All @@ -26,7 +27,7 @@ def get_status(self, pid: 'PID_TYPE') -> ProcessStatus:
"""
...

def pause_process(self, pid: 'PID_TYPE', msg: str | None = None) -> ProcessResult:
def pause_process(self, pid: 'PID_TYPE', msg_text: str | None = None) -> Any:
"""
Pause the process
Expand All @@ -52,8 +53,7 @@ def play_process(self, pid: 'PID_TYPE') -> ProcessResult:
...

def play_all(self) -> None:
"""Play all processes that are subscribed to the same coordinator
"""
"""Play all processes that are subscribed to the same coordinator"""

def kill_process(self, pid: 'PID_TYPE', msg_text: str | None = None) -> Any:
"""Kill the process
Expand Down
9 changes: 7 additions & 2 deletions src/plumpy/coordinator.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Hashable, Protocol
from re import Pattern
from typing import TYPE_CHECKING, Any, Callable, Hashable, Protocol, runtime_checkable

if TYPE_CHECKING:
ID_TYPE = Hashable
Receiver = Callable[..., Any]


@runtime_checkable
class Coordinator(Protocol):
def hook_rpc_receiver(
self,
Expand Down Expand Up @@ -36,7 +37,11 @@ def unhook_broadcast_receiver(self, identifier: 'ID_TYPE | None') -> None: ...

def unhook_task_receiver(self, identifier: 'ID_TYPE') -> None: ...

def rpc_send(self, recipient_id: Hashable, msg: Any,) -> Any: ...
def rpc_send(
self,
recipient_id: Hashable,
msg: Any,
) -> Any: ...

def broadcast_send(
self,
Expand Down
1 change: 0 additions & 1 deletion src/plumpy/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast

from plumpy.coordinator import Coordinator
from plumpy.exceptions import PersistenceError, TaskRejectedError

from . import loaders, persistence
Expand Down
35 changes: 14 additions & 21 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@
cast,
)

import kiwipy

from plumpy.broadcast_filter import BroadcastFilter
from plumpy.broadcast_filter import BroadcastFilter # type: ignore
from plumpy.coordinator import Coordinator

try:
Expand Down Expand Up @@ -944,19 +942,16 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non
# region Communication

def message_receive(self, msg: MessageType) -> Any:
"""
Coroutine called when the process receives a message from the communicator
"""Coroutine called when the process receives a message from the communicator
:param _comm: the communicator that sent the message
:param msg: the message
:return: the outcome of processing the message, the return value will be sent back as a response to the sender
"""
# self.logger.debug(
# "Process<%s>: received RPC message with communicator '%s': %r",
# self.pid,
# _comm,
# msg,
# )
self.logger.debug(
'Process<%s>: received RPC message: %r',
self.pid,
msg,
)

intent = msg[message.INTENT_KEY]

Expand All @@ -977,19 +972,17 @@ def message_receive(self, msg: MessageType) -> Any:
def broadcast_receive(
self, msg: MessageType, sender: Any, subject: Any, correlation_id: Any
) -> Optional[concurrent.futures.Future]:
"""
Coroutine called when the process receives a message from the communicator
"""Coroutine called when the process receives a message from the communicator
:param msg: the message
"""
self.logger.debug(
"Process<%s>: received broadcast message '%s': %r",
self.pid,
subject,
msg,
)

# self.logger.debug(
# "Process<%s>: received broadcast message '%s' with communicator '%s': %r",
# self.pid,
# subject,
# _comm,
# msg,
# )
# If we get a message we recognise then action it, otherwise ignore
fn = None
if subject == message.Intent.PLAY:
Expand Down
4 changes: 3 additions & 1 deletion src/plumpy/rmq/communications.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def converted(communicator: kiwipy.Communicator, *args: Any, **kwargs: Any) -> k
CommT = TypeVar('CommT', bound=kiwipy.Communicator)


def wrap_communicator(communicator: CommT, loop: Optional[asyncio.AbstractEventLoop] = None) -> 'LoopCommunicator[CommT]':
def wrap_communicator(
communicator: CommT, loop: Optional[asyncio.AbstractEventLoop] = None
) -> 'LoopCommunicator[CommT]':
"""
Wrap a communicator such that all callbacks made to any subscribers are scheduled on the
given event loop.
Expand Down
8 changes: 4 additions & 4 deletions src/plumpy/rmq/process_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import kiwipy

from plumpy import loaders
from plumpy.controller import ProcessController
from plumpy.coordinator import Coordinator
from plumpy.message import (
Intent,
Expand All @@ -29,8 +30,7 @@
ProcessStatus = Any


# FIXME: the class not fit typing of ProcessController protocol
class RemoteProcessController:
class RemoteProcessController(ProcessController):
"""
Control remote processes using coroutines that will send messages and wait
(in a non-blocking way) for their response
Expand Down Expand Up @@ -190,7 +190,7 @@ async def execute_process(
return result


class RemoteProcessThreadController:
class RemoteProcessThreadController(ProcessController):
"""
A class that can be used to control and launch remote processes
"""
Expand All @@ -212,7 +212,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future:
"""
return self._coordinator.rpc_send(pid, MessageBuilder.status())

def pause_process(self, pid: 'PID_TYPE', msg_text: str | None = None) -> kiwipy.Future:
def pause_process(self, pid: 'PID_TYPE', msg_text: str | None = None) -> Any:
"""Pause the process
:param pid: the pid of the process to pause
Expand Down
3 changes: 2 additions & 1 deletion tests/rmq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import kiwipy
import concurrent.futures

from plumpy.coordinator import Coordinator
from plumpy.exceptions import CoordinatorConnectionError

if TYPE_CHECKING:
Expand All @@ -15,7 +16,7 @@


@final
class RmqCoordinator(Generic[U]):
class RmqCoordinator(Coordinator, Generic[U]):
def __init__(self, comm: U):
self._comm = comm

Expand Down
7 changes: 7 additions & 0 deletions tests/rmq/test_coordinator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
from plumpy.coordinator import Coordinator
from . import RmqCoordinator


def test_mock_coordinator():
assert isinstance(RmqCoordinator, Coordinator)
5 changes: 5 additions & 0 deletions tests/rmq/test_process_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from kiwipy import rmq

import plumpy
from plumpy.controller import ProcessController
from plumpy.rmq import process_control

from . import RmqCoordinator
Expand Down Expand Up @@ -42,6 +43,10 @@ def async_controller(_coordinator):
def sync_controller(_coordinator):
yield process_control.RemoteProcessThreadController(_coordinator)

def test_remote_process_controller(sync_controller, async_controller):
assert isinstance(sync_controller, ProcessController)
assert isinstance(async_controller, ProcessController)


class TestRemoteProcessController:
@pytest.mark.asyncio
Expand Down
7 changes: 7 additions & 0 deletions tests/test_coordinator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
from plumpy.coordinator import Coordinator
from .utils import MockCoordinator


def test_mock_coordinator():
assert isinstance(MockCoordinator, Coordinator)
6 changes: 4 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
"""Utilities for tests"""
from __future__ import annotations

import asyncio
import collections
Expand All @@ -12,6 +13,7 @@

import plumpy
from plumpy import persistence, process_states, processes, utils
from plumpy.coordinator import Coordinator
from plumpy.exceptions import CoordinatorConnectionError
from plumpy.message import MessageBuilder
from plumpy.rmq import TaskRejected
Expand All @@ -24,7 +26,7 @@
Snapshot = collections.namedtuple('Snapshot', ['state', 'bundle', 'outputs'])


class MockCoordinator:
class MockCoordinator(Coordinator):
def __init__(self):
self._task_receivers = {}
self._broadcast_receivers = {}
Expand Down Expand Up @@ -105,7 +107,7 @@ def hook_broadcast_receiver(
self._broadcast_receivers[identifier] = receiver
return identifier

def unhook_broadcast_receiver(self, identifier: 'ID_TYPE | None') -> None:
def unhook_broadcast_receiver(self, identifier: 'ID_TYPE | None') -> None:
self._ensure_open()
try:
del self._broadcast_receivers[identifier]
Expand Down
Loading

0 comments on commit 14c6348

Please sign in to comment.