Skip to content

Commit

Permalink
Add verfication for mqtt config (#410)
Browse files Browse the repository at this point in the history
  • Loading branch information
edenhaus authored Jan 29, 2024
1 parent 9f36b73 commit 63d1772
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 8 deletions.
4 changes: 4 additions & 0 deletions deebot_client/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,7 @@ class ApiError(DeebotError):

class MapError(DeebotError):
"""Map error."""


class MqttError(DeebotError):
"""Mqtt error."""
19 changes: 14 additions & 5 deletions deebot_client/mqtt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse

from aiomqtt import Client, Message, MqttError
from aiomqtt import Client, Message, MqttError as AioMqttError
from cachetools import TTLCache

from deebot_client.const import DataType
from deebot_client.exceptions import AuthenticationError, DeebotError
from deebot_client.exceptions import AuthenticationError, MqttError

from .commands import COMMANDS_WITH_MQTT_P2P_HANDLING
from .logging_filter import get_logger
Expand Down Expand Up @@ -76,10 +76,10 @@ def create_config(
default_port = 8883
ssl_ctx = ssl.create_default_context()
case _:
raise DeebotError("Invalid scheme. Expecting mqtt or mqtts")
raise MqttError("Invalid scheme. Expecting mqtt or mqtts")

if not url.hostname:
raise DeebotError("Hostame is required")
raise MqttError("Hostame is required")

hostname = url.hostname
port = url.port or default_port
Expand Down Expand Up @@ -141,6 +141,15 @@ def last_message_received_at(self) -> datetime | None:
"""Return the datetime of the last received message or None."""
return self._last_message_received_at

async def verify_config(self) -> None:
"""Verify config by connecting to the broker."""
try:
async with await self._get_client():
_LOGGER.debug("Connection successfully")
except AioMqttError as ex:
_LOGGER.warning("Cannot connect", exc_info=True)
raise MqttError("Cannot connect") from ex

async def subscribe(self, info: SubscriberInfo) -> Callable[[], None]:
"""Subscribe for messages from given device."""
await self.connect()
Expand Down Expand Up @@ -210,7 +219,7 @@ async def listen() -> None:
finally:
for task in tasks:
task.cancel()
except MqttError:
except AioMqttError:
_LOGGER.warning(
"Connection lost; Reconnecting in %d seconds ...",
RECONNECT_INTERVAL,
Expand Down
37 changes: 34 additions & 3 deletions tests/test_mqtt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
from typing import TYPE_CHECKING, Any
from unittest.mock import DEFAULT, MagicMock, Mock, patch

from aiomqtt import Client, Message
from aiomqtt import Client, Message, MqttError as AioMqttError
from cachetools import TTLCache
import pytest

from deebot_client.commands.json.battery import GetBattery
from deebot_client.commands.json.volume import SetVolume
from deebot_client.const import DataType
from deebot_client.exceptions import AuthenticationError, DeebotError
from deebot_client.exceptions import AuthenticationError, MqttError
from deebot_client.mqtt_client import MqttClient, MqttConfiguration, create_config

from .mqtt_util import subscribe, verify_subscribe
Expand Down Expand Up @@ -387,7 +387,7 @@ def test_config_override_mqtt_url_invalid(
authenticator: Authenticator, override_mqtt_url: str, error_msg: str
) -> None:
"""Test that an invalid mqtt override url will raise a DeebotError."""
with pytest.raises(DeebotError, match=error_msg):
with pytest.raises(MqttError, match=error_msg):
MqttClient(
create_config(
device_id="123",
Expand All @@ -396,3 +396,34 @@ def test_config_override_mqtt_url_invalid(
),
authenticator,
)


async def test_verify_config(authenticator: Authenticator) -> None:
with patch("deebot_client.mqtt_client.Client", autospec=True) as client_mock:
client = MqttClient(
create_config(
device_id="123",
country="IT",
),
authenticator,
)

await client.verify_config()
client_mock.return_value.__aenter__.assert_called()


async def test_verify_config_fails(authenticator: Authenticator) -> None:
with patch("deebot_client.mqtt_client.Client", autospec=True) as client_mock:
client_mock.return_value.__aenter__.side_effect = AioMqttError
client = MqttClient(
create_config(
device_id="123",
country="IT",
),
authenticator,
)

with pytest.raises(MqttError, match="Cannot connect"):
await client.verify_config()

client_mock.return_value.__aenter__.assert_called()

0 comments on commit 63d1772

Please sign in to comment.