Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve message routing #304

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aiomqtt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
)
from .exceptions import MqttCodeError, MqttError, MqttReentrantError
from .message import Message
from .router import Router
from .topic import Topic, TopicLike, Wildcard, WildcardLike

# These are placeholders that are managed by poetry-dynamic-versioning
Expand All @@ -19,6 +20,7 @@
"__version_tuple__",
"Client",
"Message",
"Router",
"ProtocolVersion",
"ProxySettings",
"TLSParameters",
Expand Down
16 changes: 16 additions & 0 deletions aiomqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from .exceptions import MqttCodeError, MqttConnectError, MqttError, MqttReentrantError
from .message import Message
from .router import Router
from .types import (
P,
PayloadType,
Expand Down Expand Up @@ -134,6 +135,7 @@ class Client:
password: The password to authenticate with.
logger: Custom logger instance.
identifier: The client identifier. Generated automatically if ``None``.
routers: A list of routers to route messages to.
queue_type: The class to use for the queue. The default is
``asyncio.Queue``, which stores messages in FIFO order. For LIFO order,
you can use ``asyncio.LifoQueue``; For priority order you can subclass
Expand Down Expand Up @@ -186,6 +188,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
password: str | None = None,
logger: logging.Logger | None = None,
identifier: str | None = None,
routers: list[Router] | None = None,
queue_type: type[asyncio.Queue[Message]] | None = None,
protocol: ProtocolVersion | None = None,
will: Will | None = None,
Expand Down Expand Up @@ -250,6 +253,11 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
if protocol is None:
protocol = ProtocolVersion.V311

# List of routers with message handlers
if routers is None:
routers = []
self._routers = routers

# Create the underlying paho-mqtt client instance
self._client: mqtt.Client = mqtt.Client(
callback_api_version=CallbackAPIVersion.VERSION1,
Expand Down Expand Up @@ -453,6 +461,14 @@ async def publish( # noqa: PLR0913
# Wait for confirmation
await self._wait_for(confirmation.wait(), timeout=timeout)

async def route(self, message: Message) -> None:
"""Route a message to the appropriate handler."""
for router in self._routers:
for wildcard, handler in router._handlers.items():
with contextlib.suppress(ValueError):
# If we get a ValueError, we know that the topic doesn't match
await handler(message, self, *message.topic.extract(wildcard))

async def _messages(self) -> AsyncGenerator[Message, None]:
"""Async generator that yields messages from the underlying message queue."""
while True:
Expand Down
11 changes: 11 additions & 0 deletions aiomqtt/router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
class Router:
def __init__(self) -> None:
self._handlers = {}

def match(self, *args: str):
"""Add a new handler with one or multiple wildcards to the router."""
def decorator(func):
for wildcard in args:
self._handlers[wildcard] = func
return func
return decorator
43 changes: 28 additions & 15 deletions aiomqtt/topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,21 @@ def matches(self, wildcard: WildcardLike) -> bool:
Returns:
True if the topic matches the wildcard, False otherwise.
"""
try:
self.extract(wildcard)
return True
except ValueError:
return False

def extract(self, wildcard: WildcardLike) -> list[str]:
"""Extract the wildcard values from the topic.

Args:
wildcard: The wildcard to match against.

Returns:
A list of wildcard values extracted from the topic.
"""
if not isinstance(wildcard, Wildcard):
wildcard = Wildcard(wildcard)
# Split topics into levels to compare them one by one
Expand All @@ -98,21 +113,19 @@ def matches(self, wildcard: WildcardLike) -> bool:
# Shared subscriptions use the topic structure: $share/<group_id>/<topic>
wildcard_levels = wildcard_levels[2:]

def recurse(tl: list[str], wl: list[str]) -> bool:
"""Recursively match topic levels with wildcard levels."""
if not tl:
if not wl or wl[0] == "#":
return True
return False
if not wl:
return False
if wl[0] == "#":
return True
if tl[0] == wl[0] or wl[0] == "+":
return recurse(tl[1:], wl[1:])
return False

return recurse(topic_levels, wildcard_levels)
# Extract wildcard values from the topic
arguments = []
for index, level in enumerate(wildcard_levels):
if level == "#":
return arguments + topic_levels[index:]
if len(topic_levels) == index:
raise ValueError("Topic does not match wildcard")
if level != "+" and level != topic_levels[index]:
raise ValueError("Topic does not match wildcard")
arguments.append(topic_levels[index])
if len(topic_levels) > index + 1:
raise ValueError("Topic does not match wildcard")
return arguments


TopicLike: TypeAlias = "str | Topic"
Loading