Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement automatic reconnection #287

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 74 additions & 35 deletions aiomqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ class Client:
password: The password to authenticate with.
logger: Custom logger instance.
identifier: The client identifier. Generated automatically if ``None``.
reconnect: If ``True``, the client will automatically reconnect to the broker
if the connection is lost. Defaults to ``False``.
queue_type: The class to use for the queue. The default is
``asyncio.Queue``, which stores messages in FIFO order. For LIFO order,
you can use ``asyncio.LifoQueue``; for priority order you can subclass
Expand Down Expand Up @@ -181,6 +183,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
password: str | None = None,
logger: logging.Logger | None = None,
identifier: str | None = None,
reconnect: bool = False,
queue_type: type[asyncio.Queue[Message]] | None = None,
protocol: ProtocolVersion | None = None,
will: Will | None = None,
Expand All @@ -206,6 +209,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
) -> None:
self._hostname = hostname
self._port = port
self._reconnect = reconnect
self._keepalive = keepalive
self._bind_address = bind_address
self._bind_port = bind_port
Expand All @@ -225,7 +229,10 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
self._pending_unsubscribes: dict[int, asyncio.Event] = {}
self._pending_publishes: dict[int, asyncio.Event] = {}
self.pending_calls_threshold: int = 10

# Background tasks
self._misc_task: asyncio.Task[None] | None = None
self._reconnection_task: asyncio.Task[None] | None = None

# Queue that holds incoming messages
if queue_type is None:
Expand Down Expand Up @@ -432,9 +439,17 @@ async def publish( # noqa: PLR0913
**kwargs: Additional keyword arguments to pass to paho-mqtt's publish
method.
"""
info = self._client.publish(
topic, payload, qos, retain, properties, *args, **kwargs
) # [2]
while True:
info = self._client.publish(
topic, payload, qos, retain, properties, *args, **kwargs
) # [2]
if not (info.rc == mqtt.MQTT_ERR_NO_CONN and self._reconnect):
break
while True:
with contextlib.suppress(asyncio.CancelledError):
await self._connected
break
self._connected = asyncio.Future()
# Early out on error
if info.rc != mqtt.MQTT_ERR_SUCCESS:
raise MqttCodeError(info.rc, "Could not publish message")
Expand Down Expand Up @@ -677,43 +692,65 @@ async def _misc_loop(self) -> None:
while self._client.loop_misc() == mqtt.MQTT_ERR_SUCCESS:
await asyncio.sleep(1)

async def _connect(self) -> None:
    """Connect to the broker and wait until the connection is established.

    The blocking paho-mqtt ``connect()`` call is run in an executor thread.
    Any ``OSError`` or ``mqtt.WebsocketConnectionError`` it raises is converted
    to ``MqttError``.

    If ``self._reconnect`` is ``True``, a failed attempt is logged and retried
    every 2 seconds, indefinitely. Otherwise the client lock is released (if
    it is held) and the error is re-raised.

    Raises:
        MqttError: If the connection attempt fails and ``self._reconnect`` is
            ``False``.
    """
    retry_delay = 2  # Seconds to wait between two connection attempts
    loop = asyncio.get_running_loop()
    while True:
        try:
            try:
                # [3] Run connect() within an executor thread, since it blocks on
                # socket connection for up to `keepalive` seconds:
                # https://git.io/Jt5Yc
                await loop.run_in_executor(
                    None,
                    self._client.connect,
                    self._hostname,
                    self._port,
                    self._keepalive,
                    self._bind_address,
                    self._bind_port,
                    self._clean_start,
                    self._properties,
                )
                _set_client_socket_defaults(
                    self._client.socket(), self._socket_options
                )
            # Convert all possible paho-mqtt Client.connect exceptions to our
            # MqttError. See:
            # https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L1770
            except (OSError, mqtt.WebsocketConnectionError) as exc:
                raise MqttError(str(exc)) from None
            await self._wait_for(self._connected, timeout=None)
        except MqttError as exc:
            # Reset internal state if the connection attempt failed, so that the
            # next attempt starts from fresh futures
            if self._connected.done():
                self._connected = asyncio.Future()
            if self._disconnected.done():
                self._disconnected = asyncio.Future()
            if not self._reconnect:
                # Only release a lock that is actually held; releasing an
                # unlocked asyncio.Lock raises RuntimeError, which would mask
                # the MqttError we want to propagate
                if self._lock.locked():
                    self._lock.release()
                raise
            # Include the cause so that repeated failures can be diagnosed
            self._logger.warning(
                "Failed to connect (%s). Trying again in %d seconds...",
                exc,
                retry_delay,
            )
            await asyncio.sleep(retry_delay)
        else:
            self._logger.info("Successfully connected to the broker.")
            return

async def _reconnection(self) -> None:
    """Watch for connection loss and re-establish the connection each time.

    Runs forever: waits for the ``_disconnected`` future to complete, replaces
    both connection-state futures with fresh ones, and then calls
    ``self._connect()``.
    """
    while True:
        # An MqttError stored on the future is ignored here; either way the
        # completed future is the signal to reconnect
        try:
            await self._disconnected
        except MqttError:
            pass
        self._logger.warning("Connection lost. Reconnecting...")
        # Fresh futures for the next connect/disconnect cycle
        self._connected, self._disconnected = asyncio.Future(), asyncio.Future()
        await self._connect()

async def __aenter__(self) -> Self:
    """Connect to the broker.

    Acquires the client lock, connects through ``self._connect()`` and, if
    ``self._reconnect`` is ``True``, starts the background reconnection task.

    Returns:
        The client itself.

    Raises:
        MqttReentrantError: If the client lock is already held, i.e. the
            context manager is entered while it is already active.
        MqttError: If connecting fails and ``self._reconnect`` is ``False``.
    """
    if self._lock.locked():
        msg = "The client context manager is reusable, but not reentrant"
        raise MqttReentrantError(msg)
    await self._lock.acquire()
    try:
        # All connection logic (executor call, socket defaults, waiting for the
        # connection and resetting the state futures) lives in `_connect`;
        # repeating it inline here would connect to the broker twice
        await self._connect()
    except BaseException:
        # `_connect` releases the lock itself only when it gives up with an
        # MqttError. For anything else (e.g. cancellation while connecting) the
        # lock would stay held and make the client permanently unusable
        if self._lock.locked():
            self._lock.release()
        raise
    # Start the reconnection task
    if self._reconnect:
        self._reconnection_task = asyncio.create_task(self._reconnection())
    return self

async def __aexit__(
Expand All @@ -723,8 +760,10 @@ async def __aexit__(
tb: TracebackType | None,
) -> None:
"""Disconnect from the broker."""
if self._reconnect:
self._reconnection_task.cancel()
# Return early if the client is already disconnected
if self._disconnected.done():
# Return early if the client is already disconnected
if self._lock.locked():
self._lock.release()
if (exc := self._disconnected.exception()) is not None:
Expand Down
Loading