From 2ec5c815dbca3fc6d1b4671bf7b0e2f6a066766b Mon Sep 17 00:00:00 2001 From: Aaron Bach Date: Sun, 31 Jul 2022 12:54:25 -0600 Subject: [PATCH] Use single MQTT connection for each payload publish (#236) --- ecowitt2mqtt/core.py | 16 +- ecowitt2mqtt/helpers/publisher/__init__.py | 26 +--- ecowitt2mqtt/helpers/publisher/hass.py | 62 ++++---- ecowitt2mqtt/helpers/publisher/topic.py | 28 ++-- ecowitt2mqtt/runtime.py | 168 +++++++++++++++++++++ ecowitt2mqtt/server.py | 83 ---------- pyproject.toml | 2 +- tests/common.py | 16 -- tests/conftest.py | 20 ++- tests/publisher/test_hass_discovery.py | 70 +++++---- tests/publisher/test_topic_publisher.py | 30 ++-- tests/test_runtime.py | 63 ++++++++ tests/test_server.py | 56 ------- 13 files changed, 351 insertions(+), 289 deletions(-) create mode 100644 ecowitt2mqtt/runtime.py delete mode 100644 ecowitt2mqtt/server.py create mode 100644 tests/test_runtime.py delete mode 100644 tests/test_server.py diff --git a/ecowitt2mqtt/core.py b/ecowitt2mqtt/core.py index 1e083be3..f4c8f9d3 100644 --- a/ecowitt2mqtt/core.py +++ b/ecowitt2mqtt/core.py @@ -6,9 +6,9 @@ from typing import Any from ecowitt2mqtt.config import Config -from ecowitt2mqtt.const import CONF_VERBOSE, LEGACY_ENV_LOG_LEVEL, LOGGER +from ecowitt2mqtt.const import CONF_VERBOSE, LEGACY_ENV_LOG_LEVEL from ecowitt2mqtt.helpers.logging import TyperLoggerHandler -from ecowitt2mqtt.server import Server +from ecowitt2mqtt.runtime import Runtime class Ecowitt: # pylint: disable=too-few-public-methods @@ -27,10 +27,14 @@ def __init__(self, params: dict[str, Any]) -> None: handlers=(TyperLoggerHandler(),), ) - self.config = Config(params) - self.server = Server(self) + self._config = Config(params) + self._runtime = Runtime(self) + + @property + def config(self) -> Config: + """Return the config object.""" + return self._config async def async_start(self) -> None: """Start ecowitt2mqtt.""" - LOGGER.info("Starting ecowitt2mqtt") - await self.server.async_start() + await self._runtime.async_start() diff --git a/ecowitt2mqtt/helpers/publisher/__init__.py b/ecowitt2mqtt/helpers/publisher/__init__.py index 0d162aa1..c7d0766e 100644 --- a/ecowitt2mqtt/helpers/publisher/__init__.py +++ b/ecowitt2mqtt/helpers/publisher/__init__.py @@ -2,29 +2,18 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager from datetime import datetime import json -from ssl import SSLContext from typing import TYPE_CHECKING, Any from asyncio_mqtt import Client -from ecowitt2mqtt.const import LOGGER -from ecowitt2mqtt.errors import EcowittError from ecowitt2mqtt.helpers.typing import DataValueType if TYPE_CHECKING: from ecowitt2mqtt.core import Ecowitt -class PublishError(EcowittError): - """Define an error related to a failed data publish.""" - - pass - - def generate_mqtt_payload(data: DataValueType) -> bytes: """Generate a binary MQTT payload from input data.""" if isinstance(data, dict): @@ -50,20 +39,7 @@ def __init__(self, ecowitt: Ecowitt) -> None: """Initialize.""" self.ecowitt = ecowitt - @asynccontextmanager - async def async_get_client(self) -> AsyncIterator[Client]: - """Get an MQTT client.""" - async with Client( - self.ecowitt.config.mqtt_broker, - logger=LOGGER, - password=self.ecowitt.config.mqtt_password, - port=self.ecowitt.config.mqtt_port, - tls_context=SSLContext() if self.ecowitt.config.mqtt_tls else None, - username=self.ecowitt.config.mqtt_username, - ) as client: - yield client - @abstractmethod - async def async_publish(self, data: dict[str, Any]) -> None: + async def async_publish(self, client: Client, data: dict[str, Any]) -> None: """Publish the data.""" raise NotImplementedError() diff --git a/ecowitt2mqtt/helpers/publisher/hass.py b/ecowitt2mqtt/helpers/publisher/hass.py index 844969e6..5b263b85 100644 --- a/ecowitt2mqtt/helpers/publisher/hass.py +++ b/ecowitt2mqtt/helpers/publisher/hass.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, TypedDict -from asyncio_mqtt import MqttError +from asyncio_mqtt import Client, MqttError from ecowitt2mqtt.backports.enum import StrEnum from ecowitt2mqtt.const import ( @@ -78,11 +78,7 @@ get_battery_strategy, ) from ecowitt2mqtt.helpers.device import Device -from ecowitt2mqtt.helpers.publisher import ( - MqttPublisher, - PublishError, - generate_mqtt_payload, -) +from ecowitt2mqtt.helpers.publisher import MqttPublisher, generate_mqtt_payload from ecowitt2mqtt.helpers.typing import DataValueType if TYPE_CHECKING: @@ -470,44 +466,42 @@ def _generate_discovery_payload( return payload - async def async_publish(self, data: dict[str, DataValueType]) -> None: + async def async_publish( + self, client: Client, data: dict[str, DataValueType] + ) -> None: """Publish to MQTT.""" processed_data = ProcessedData(self.ecowitt, data) tasks = [] try: - async with self.async_get_client() as client: - for payload_key, data_point in processed_data.output.items(): - discovery_payload = self._generate_discovery_payload( - processed_data.device, payload_key, data_point - ) - - for topic, payload in ( - (discovery_payload.topic, discovery_payload.payload), - ( - discovery_payload.payload["availability_topic"], - get_availability_payload(data_point), - ), - (discovery_payload.payload["state_topic"], data_point.value), - ): - tasks.append( - asyncio.create_task( - client.publish( - topic, - payload=generate_mqtt_payload(payload), - retain=self.ecowitt.config.mqtt_retain, - ) + for payload_key, data_point in processed_data.output.items(): + discovery_payload = self._generate_discovery_payload( + processed_data.device, payload_key, data_point + ) + + for topic, payload in ( + (discovery_payload.topic, discovery_payload.payload), + ( + discovery_payload.payload["availability_topic"], + get_availability_payload(data_point), + ), + (discovery_payload.payload["state_topic"], data_point.value), + ): + tasks.append( + asyncio.create_task( + client.publish( + topic, + payload=generate_mqtt_payload(payload), + retain=self.ecowitt.config.mqtt_retain, ) ) + ) - await asyncio.gather(*tasks) - except MqttError as err: + await asyncio.gather(*tasks) + except MqttError: for task in tasks: task.cancel() - - raise PublishError( - f"Error while publishing to Home Assistant MQTT Discovery: {err}" - ) from err + raise LOGGER.info("Published to Home Assistant MQTT Discovery") LOGGER.debug("Published data: %s", processed_data.output) diff --git a/ecowitt2mqtt/helpers/publisher/topic.py b/ecowitt2mqtt/helpers/publisher/topic.py index 9b871b39..5d8d1a1d 100644 --- a/ecowitt2mqtt/helpers/publisher/topic.py +++ b/ecowitt2mqtt/helpers/publisher/topic.py @@ -1,38 +1,30 @@ """Define MQTT publishing.""" from __future__ import annotations -from asyncio_mqtt import MqttError +from asyncio_mqtt import Client from ecowitt2mqtt.const import LOGGER from ecowitt2mqtt.data import ProcessedData -from ecowitt2mqtt.helpers.publisher import ( - MqttPublisher, - PublishError, - generate_mqtt_payload, -) +from ecowitt2mqtt.helpers.publisher import MqttPublisher, generate_mqtt_payload from ecowitt2mqtt.helpers.typing import DataValueType class TopicPublisher(MqttPublisher): """Define an MQTT publisher that publishes to a topic.""" - async def async_publish(self, data: dict[str, DataValueType]) -> None: + async def async_publish( + self, client: Client, data: dict[str, DataValueType] + ) -> None: """Publish to MQTT.""" if not self.ecowitt.config.raw_data: processed_data = ProcessedData(self.ecowitt, data) data = {key: value.value for key, value in processed_data.output.items()} - try: - async with self.async_get_client() as client: - await client.publish( - self.ecowitt.config.mqtt_topic, - payload=generate_mqtt_payload(data), - retain=self.ecowitt.config.mqtt_retain, - ) - except MqttError as err: - raise PublishError( - f"Error while publishing to {self.ecowitt.config.mqtt_topic}: {err}" - ) from err + await client.publish( + self.ecowitt.config.mqtt_topic, + payload=generate_mqtt_payload(data), + retain=self.ecowitt.config.mqtt_retain, + ) LOGGER.info("Published to %s", self.ecowitt.config.mqtt_topic) LOGGER.debug("Published data: %s", data) diff --git a/ecowitt2mqtt/runtime.py b/ecowitt2mqtt/runtime.py new file mode 100644 index 00000000..bc5fbb5d --- /dev/null +++ b/ecowitt2mqtt/runtime.py @@ -0,0 +1,168 @@ +"""Define runtime management.""" +from __future__ import annotations + +import asyncio +import logging +import signal +from ssl import SSLContext +import traceback +from types import FrameType +from typing import TYPE_CHECKING, Any + +from asyncio_mqtt import Client, MqttError +from fastapi import FastAPI, Request, Response, status +import uvicorn + +from ecowitt2mqtt.const import LOGGER +from ecowitt2mqtt.helpers.publisher.factory import get_publisher + +if TYPE_CHECKING: + from ecowitt2mqtt.core import Ecowitt + +DEFAULT_HOST = "0.0.0.0" +DEFAULT_MAX_RETRY_INTERVAL = 60 + +LOG_LEVEL_DEBUG = "debug" +LOG_LEVEL_ERROR = "error" + +HANDLED_SIGNALS = ( + signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. + signal.SIGTERM, # Unix signal 15. Sent by `kill `. +) + + +class MyCustomUvicornServer(uvicorn.Server): # type: ignore + """Define a Uvicorn server that doesn't swallow signals.""" + + def install_signal_handlers(self) -> None: + """Don't swallow signals.""" + pass + + +class Runtime: + """Define the runtime manager.""" + + def __init__(self, ecowitt: Ecowitt) -> None: + """Initialize.""" + self.ecowitt = ecowitt + + app = FastAPI() + app.post( + ecowitt.config.endpoint, + status_code=status.HTTP_204_NO_CONTENT, + response_class=Response, + )(self._async_post_data) + self._server = MyCustomUvicornServer( + config=uvicorn.Config( + app, + host=DEFAULT_HOST, + port=ecowitt.config.port, + log_level="debug" if ecowitt.config.verbose else "info", + ) + ) + + self._latest_payload: dict[str, Any] | None = None + self._new_payload_condition = asyncio.Condition() + self._publisher = get_publisher(ecowitt) + self._runtime_tasks: list[asyncio.Task] = [] + + # Remove the existing Uvicorn logger handler so that we don't get duplicates: + # https://github.com/encode/uvicorn/issues/1285 + uvicorn_logger = logging.getLogger("uvicorn") + uvicorn_logger.removeHandler(uvicorn_logger.handlers[0]) + + async def _async_create_mqtt_loop(self) -> None: + """Create the MQTT process loop.""" + LOGGER.debug("Starting MQTT process loop") + + retry_attempt = 0 + while True: + try: + async with Client( + self.ecowitt.config.mqtt_broker, + logger=LOGGER, + password=self.ecowitt.config.mqtt_password, + port=self.ecowitt.config.mqtt_port, + tls_context=SSLContext() if self.ecowitt.config.mqtt_tls else None, + username=self.ecowitt.config.mqtt_username, + ) as client: + while True: + async with self._new_payload_condition: + await self._new_payload_condition.wait() + LOGGER.debug("Publishing payload: %s", self._latest_payload) + assert self._latest_payload + await self._publisher.async_publish( + client, self._latest_payload + ) + retry_attempt = 0 + + if self.ecowitt.config.diagnostics: + LOGGER.debug("*** DIAGNOSTICS COLLECTED") + self.stop() + except asyncio.CancelledError: + LOGGER.debug("Stopping MQTT process loop") + raise + except MqttError as err: + LOGGER.error("There was an MQTT error: %s", err) + LOGGER.debug("".join(traceback.format_tb(err.__traceback__))) + + retry_attempt += 1 + delay = min(retry_attempt**2, DEFAULT_MAX_RETRY_INTERVAL) + LOGGER.info( + "Attempting MQTT reconnection in %s seconds (attempt %s)", + delay, + retry_attempt, + ) + await asyncio.sleep(delay) + + async def _async_create_server(self) -> None: + """Create the REST API server.""" + LOGGER.debug("Starting REST API server") + + try: + await self._server.serve() + except asyncio.CancelledError: + LOGGER.debug("Stopping REST API server") + raise + + async def _async_post_data(self, request: Request) -> Response: + """Define an endpoint for the Ecowitt device to post data to.""" + payload = dict(await request.form()) + LOGGER.debug("Received data from the Ecowitt device: %s", payload) + async with self._new_payload_condition: + self._latest_payload = payload + self._new_payload_condition.notify_all() + + async def async_start(self) -> None: + """Start the runtime.""" + loop = asyncio.get_running_loop() + + def handle_exit_signal(sig: int, frame: FrameType | None) -> None: + """Handle an exit signal.""" + if self._server.should_exit and sig == signal.SIGINT: + self._server.force_exit = True + else: + self._server.should_exit = True + self.stop() + + try: + for sig in HANDLED_SIGNALS: + loop.add_signal_handler(sig, handle_exit_signal, sig, None) + except NotImplementedError: + # Windows + for sig in HANDLED_SIGNALS: + signal.signal(sig, handle_exit_signal) + + for coro_func in self._async_create_mqtt_loop, self._async_create_server: + self._runtime_tasks.append(asyncio.create_task(coro_func())) + + try: + await asyncio.gather(*self._runtime_tasks) + except asyncio.CancelledError: + await asyncio.sleep(0.1) + LOGGER.debug("Runtime shutdown complete") + + def stop(self) -> None: + """Stop the REST API server.""" + for task in self._runtime_tasks: + task.cancel() diff --git a/ecowitt2mqtt/server.py b/ecowitt2mqtt/server.py deleted file mode 100644 index 51ec994c..00000000 --- a/ecowitt2mqtt/server.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Define a REST API server for Ecowitt devices to interact with.""" -from __future__ import annotations - -import asyncio -import traceback -from typing import TYPE_CHECKING - -from fastapi import FastAPI, Request, Response, status -import uvicorn - -from ecowitt2mqtt.const import LOGGER -from ecowitt2mqtt.helpers.publisher import PublishError -from ecowitt2mqtt.helpers.publisher.factory import get_publisher - -if TYPE_CHECKING: - from ecowitt2mqtt.core import Ecowitt - -DEFAULT_HOST = "0.0.0.0" - -LOG_LEVEL_DEBUG = "debug" -LOG_LEVEL_ERROR = "error" - - -class Server: - """Define the server management object.""" - - def __init__(self, ecowitt: Ecowitt) -> None: - """Initialize.""" - self._startup_task: asyncio.Task | None = None - - self.app = FastAPI() - self.ecowitt = ecowitt - self.publisher = get_publisher(ecowitt) - self.server = uvicorn.Server( - config=uvicorn.Config( - self.app, - host=DEFAULT_HOST, - port=ecowitt.config.port, - log_level="debug" if ecowitt.config.verbose else "error", - ) - ) - - async def _async_post_data(self, request: Request) -> Response: - """Define an endpoint for the Ecowitt device to post data to.""" - payload = await request.form() - LOGGER.debug("Received data from the Ecowitt device: %s", dict(payload)) - - try: - await self.publisher.async_publish(payload) - except PublishError as err: - LOGGER.error("Unable to publish payload: %s", err) - LOGGER.debug("".join(traceback.format_tb(err.__traceback__))) - - if self.ecowitt.config.diagnostics: - LOGGER.debug("*** DIAGNOSTICS COLLECTED") - self.stop() - - async def async_start(self) -> None: - """Start the REST API server.""" - LOGGER.debug( - "Starting REST API server: http://%s:%s%s", - DEFAULT_HOST, - self.ecowitt.config.port, - self.ecowitt.config.endpoint, - ) - - self.app.post( - self.ecowitt.config.endpoint, - status_code=status.HTTP_204_NO_CONTENT, - response_class=Response, - )(self._async_post_data) - - self._startup_task = asyncio.create_task(self.server.serve()) - try: - await self._startup_task - except asyncio.CancelledError: - LOGGER.debug("REST API server shutdown complete") - - def stop(self) -> None: - """Stop the REST API server.""" - if self._startup_task: - self._startup_task.cancel() - self._startup_task = None diff --git a/pyproject.toml b/pyproject.toml index 66731888..2d27053e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools >= 35.0.2", "wheel >= 0.29.0", "poetry>=0.12"] build-backend = "poetry.core.masonry.api" [tool.coverage.report] -exclude_lines = ["raise NotImplementedError", "TYPE_CHECKING"] +exclude_lines = ["TYPE_CHECKING", "NotImplementedError", "handle_exit_signal"] fail_under = 100 [tool.coverage.run] diff --git a/tests/common.py b/tests/common.py index ca5f97f0..82c09163 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,6 +1,4 @@ """Define common test utilities.""" -import asyncio -from contextlib import asynccontextmanager import os from ecowitt2mqtt.const import ( @@ -25,7 +23,6 @@ CONF_VERBOSE, UNIT_SYSTEM_IMPERIAL, ) -from ecowitt2mqtt.core import Ecowitt from ecowitt2mqtt.helpers.calculator.battery import BatteryStrategy TEST_ENDPOINT = "/data/report" @@ -84,19 +81,6 @@ """ -@asynccontextmanager -async def async_run_server(ecowitt: Ecowitt): - """Run ecowitt2mqtt.""" - start_task = asyncio.create_task(ecowitt.async_start()) - await asyncio.sleep(0.1) - try: - yield - finally: - await ecowitt.server.server.shutdown() - start_task.cancel() - await asyncio.sleep(0.1) - - def load_fixture(filename): """Load a fixture.""" path = os.path.join(os.path.dirname(__file__), "fixtures", filename) diff --git a/tests/conftest.py b/tests/conftest.py index b2cefc5a..d7d9f068 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,10 @@ """Define dynamic fixtures.""" from __future__ import annotations +import asyncio import json import tempfile -from unittest.mock import AsyncMock, MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest import pytest_asyncio @@ -78,10 +79,21 @@ def runner_fixture(): @pytest_asyncio.fixture(name="setup_asyncio_mqtt") async def setup_asyncio_mqtt_fixture(ecowitt, mock_asyncio_mqtt_client): """Define a fixture to patch asyncio-mqtt properly.""" - with patch( - "ecowitt2mqtt.helpers.publisher.Client", - ) as mock_client_class: + with patch("ecowitt2mqtt.runtime.Client") as mock_client_class: mock_client_class.return_value.__aenter__.return_value = ( mock_asyncio_mqtt_client ) yield + + +@pytest_asyncio.fixture(name="setup_uvicorn_server") +async def setup_uvicorn_server_fixture(ecowitt): + """Define a fixture to patch Uvicorn properly.""" + start_task = asyncio.create_task(ecowitt.async_start()) + await asyncio.sleep(0.1) + try: + yield + finally: + await ecowitt._runtime._server.shutdown() + start_task.cancel() + await asyncio.sleep(0.1) diff --git a/tests/publisher/test_hass_discovery.py b/tests/publisher/test_hass_discovery.py index 4d011942..d3cc0e3f 100644 --- a/tests/publisher/test_hass_discovery.py +++ b/tests/publisher/test_hass_discovery.py @@ -1,6 +1,6 @@ """Define tests for the Home Assistant MQTT Discovery publisher.""" import logging -from unittest.mock import AsyncMock, call +from unittest.mock import call from asyncio_mqtt import MqttError import pytest @@ -11,7 +11,6 @@ CONF_HASS_ENTITY_ID_PREFIX, ) from ecowitt2mqtt.helpers.calculator.battery import BatteryStrategy -from ecowitt2mqtt.helpers.publisher import PublishError from ecowitt2mqtt.helpers.publisher.factory import get_publisher from ecowitt2mqtt.helpers.publisher.hass import HomeAssistantDiscoveryPublisher @@ -48,7 +47,9 @@ async def test_publish( device_data, ecowitt, mock_asyncio_mqtt_client, setup_asyncio_mqtt ): """Test publishing a payload.""" - await ecowitt.server.publisher.async_publish(device_data) + await ecowitt._runtime._publisher.async_publish( + mock_asyncio_mqtt_client, device_data + ) mock_asyncio_mqtt_client.publish.assert_has_awaits( [ call( @@ -1556,7 +1557,9 @@ async def test_publish_custom_entity_id_prefix( device_data, ecowitt, mock_asyncio_mqtt_client, setup_asyncio_mqtt ): """Test publishing a payload with custom HASS entity ID prefix.""" - await ecowitt.server.publisher.async_publish(device_data) + await ecowitt._runtime._publisher.async_publish( + mock_asyncio_mqtt_client, device_data + ) mock_asyncio_mqtt_client.publish.assert_has_awaits( [ call( @@ -3048,6 +3051,30 @@ async def test_publish_custom_entity_id_prefix( ) +@pytest.mark.asyncio +@pytest.mark.parametrize( + "config,device_data_filename,mqtt_publish_side_effect", + [ + ( + { + **TEST_CONFIG_JSON, + CONF_HASS_DISCOVERY: True, + }, + "payload_gw2000a_2.json", + [None, None, None, MqttError], + ), + ], +) +async def test_publish_error_mqtt( + device_data, ecowitt, mock_asyncio_mqtt_client, setup_asyncio_mqtt +): + """Test handling an asyncio-mqtt error when publishing.""" + with pytest.raises(MqttError): + await ecowitt._runtime._publisher.async_publish( + mock_asyncio_mqtt_client, device_data + ) + + @pytest.mark.asyncio @pytest.mark.parametrize( "config", @@ -3064,7 +3091,9 @@ async def test_publish_numeric_battery_strategy( device_data, ecowitt, mock_asyncio_mqtt_client, setup_asyncio_mqtt ): """Test publishing a payload with numeric battery strategy.""" - await ecowitt.server.publisher.async_publish(device_data) + await ecowitt._runtime._publisher.async_publish( + mock_asyncio_mqtt_client, device_data + ) mock_asyncio_mqtt_client.publish.assert_has_awaits( [ call( @@ -4556,26 +4585,6 @@ async def test_publish_numeric_battery_strategy( ) -@pytest.mark.asyncio -@pytest.mark.parametrize( - "config,device_data_filename,mqtt_publish_side_effect", - [ - ( - { - **TEST_CONFIG_JSON, - CONF_HASS_DISCOVERY: True, - }, - "payload_gw2000a_2.json", - [None, None, None, MqttError], - ), - ], -) -async def test_publish_error_mqtt(device_data, ecowitt, setup_asyncio_mqtt): - """Test handling an asyncio-mqtt error when publishing.""" - with pytest.raises(PublishError): - await ecowitt.server.publisher.async_publish(device_data) - - @pytest.mark.asyncio @pytest.mark.parametrize( "config", @@ -4594,7 +4603,9 @@ async def test_no_entity_description( caplog.set_level(logging.DEBUG) device_data["random"] = "value" - await ecowitt.server.publisher.async_publish(device_data) + await ecowitt._runtime._publisher.async_publish( + mock_asyncio_mqtt_client, device_data + ) mock_asyncio_mqtt_client.publish.assert_has_awaits( [ call( @@ -6100,7 +6111,6 @@ async def test_no_entity_description( ] ) - -# assert any( -# m for m in caplog.messages if 'Missing entity description for "random"' in m -# ) + assert any( + m for m in caplog.messages if 'Missing entity description for "random"' in m + ) diff --git a/tests/publisher/test_topic_publisher.py b/tests/publisher/test_topic_publisher.py index 86719ac2..412bc004 100644 --- a/tests/publisher/test_topic_publisher.py +++ b/tests/publisher/test_topic_publisher.py @@ -1,11 +1,9 @@ """Define tests for the MQTT Topic publisher.""" -from unittest.mock import AsyncMock - from asyncio_mqtt import Client, MqttError import pytest from ecowitt2mqtt.const import CONF_MQTT_RETAIN, CONF_RAW_DATA -from ecowitt2mqtt.helpers.publisher import PublishError, generate_mqtt_payload +from ecowitt2mqtt.helpers.publisher import generate_mqtt_payload from ecowitt2mqtt.helpers.publisher.factory import get_publisher from ecowitt2mqtt.helpers.publisher.topic import TopicPublisher @@ -23,7 +21,9 @@ async def test_publish_processed( device_data, ecowitt, mock_asyncio_mqtt_client, setup_asyncio_mqtt ): """Test publishing a processed payload to an TopicPublisher.""" - await ecowitt.server.publisher.async_publish(device_data) + await ecowitt._runtime._publisher.async_publish( + mock_asyncio_mqtt_client, device_data + ) mock_asyncio_mqtt_client.publish.assert_awaited_with( TEST_MQTT_TOPIC, payload=b'{"runtime": 319206.0, "tempin": 79.5, "humidityin": 31.0, "baromrel": 24.74, "baromabs": 24.74, "temp": 93.2, "humidity": 64.0, "winddir": 139.0, "windspeed": 20.89, "windgust": 1.12, "maxdailygust": 8.05, "solarradiation": 264.61, "uv": 2.0, "rainrate": 0.0, "eventrain": 0.0, "hourlyrain": 0.0, "dailyrain": 0.0, "weeklyrain": 0.0, "monthlyrain": 2.177, "yearlyrain": 4.441, "lightning_num": 13.0, "lightning": 0.6, "lightning_time": "2022-04-20T17:17:17+00:00", "wh65batt": "OFF", "dewpoint": 79.2, "feelslike": 111.1, "frostpoint": 70.3, "frostrisk": "No risk", "heatindex": 111.1, "humidityabs": 0.0, "humidityabsin": 0.0, "safe_exposure_time_skin_type_1": 83.3, "safe_exposure_time_skin_type_2": 100.0, "safe_exposure_time_skin_type_3": 133.3, "safe_exposure_time_skin_type_4": 166.7, "safe_exposure_time_skin_type_5": 266.7, "safe_exposure_time_skin_type_6": 433.3, "simmerindex": 113.9, "simmerzone": "Danger of heatstroke", "solarradiation_lux": 33494.9, "solarradiation_perceived": 90.0, "thermalperception": "Severely high", "windchill": null}', @@ -45,27 +45,23 @@ async def test_publish_raw( device_data, ecowitt, mock_asyncio_mqtt_client, setup_asyncio_mqtt ): """Test publishing a raw payload to an TopicPublisher.""" - await ecowitt.server.publisher.async_publish(device_data) + await ecowitt._runtime._publisher.async_publish( + mock_asyncio_mqtt_client, device_data + ) mock_asyncio_mqtt_client.publish.assert_awaited_with( TEST_MQTT_TOPIC, payload=generate_mqtt_payload(device_data), retain=False ) @pytest.mark.asyncio -@pytest.mark.parametrize("mqtt_publish_side_effect", [MqttError]) -async def test_publish_error_mqtt(device_data, ecowitt, setup_asyncio_mqtt): - """Test handling an asyncio-mqtt error when publishing.""" - with pytest.raises(PublishError): - await ecowitt.server.publisher.async_publish(device_data) - - -@pytest.mark.asyncio -async def test_publish_error_unserializable(device_data, ecowitt, setup_asyncio_mqtt): +async def test_publish_error_unserializable( + device_data, ecowitt, mock_asyncio_mqtt_client, setup_asyncio_mqtt +): """Test handling a serialization error when publishing.""" device_data["Test"] = object() publisher = get_publisher(ecowitt) with pytest.raises(TypeError): - await publisher.async_publish(device_data) + await publisher.async_publish(mock_asyncio_mqtt_client, device_data) @pytest.mark.asyncio @@ -83,7 +79,9 @@ async def test_publish_retain( device_data, ecowitt, mock_asyncio_mqtt_client, setup_asyncio_mqtt ): """Test publishing a retained raw payload to an TopicPublisher.""" - await ecowitt.server.publisher.async_publish(device_data) + await ecowitt._runtime._publisher.async_publish( + mock_asyncio_mqtt_client, device_data + ) mock_asyncio_mqtt_client.publish.assert_awaited_with( TEST_MQTT_TOPIC, payload=generate_mqtt_payload(device_data), retain=True ) diff --git a/tests/test_runtime.py b/tests/test_runtime.py new file mode 100644 index 00000000..9c3bae23 --- /dev/null +++ b/tests/test_runtime.py @@ -0,0 +1,63 @@ +"""Define tests for the API server.""" +from __future__ import annotations + +import asyncio +import logging +import os +import signal +import subprocess +from unittest.mock import AsyncMock + +from aiohttp import ClientSession +from asyncio_mqtt import MqttError +import pytest + +from ecowitt2mqtt.const import CONF_DIAGNOSTICS + +from tests.common import TEST_CONFIG_JSON, TEST_ENDPOINT, TEST_PORT + + +@pytest.mark.asyncio +@pytest.mark.parametrize("config", [{**TEST_CONFIG_JSON, CONF_DIAGNOSTICS: True}]) +async def test_get_diagnostics( + caplog, device_data, ecowitt, setup_asyncio_mqtt, setup_uvicorn_server +): + """Test getting diagnostics.""" + caplog.set_level(logging.DEBUG) + async with ClientSession() as session: + resp = await session.request( + "post", + f"http://0.0.0.0:{TEST_PORT}{TEST_ENDPOINT}", + data=device_data, + ) + assert resp.status == 204 + assert any(m for m in caplog.messages if "DIAGNOSTICS COLLECTED" in m) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("mqtt_publish_side_effect", [AsyncMock(side_effect=MqttError)]) +async def test_publish_failure( + caplog, device_data, ecowitt, setup_asyncio_mqtt, setup_uvicorn_server +): + """Test a failed MQTT publish.""" + async with ClientSession() as session: + await session.request( + "post", + f"http://0.0.0.0:{TEST_PORT}{TEST_ENDPOINT}", + data=device_data, + ) + assert any(m for m in caplog.messages if "There was an MQTT error" in m) + + +@pytest.mark.asyncio +async def test_publish_success( + device_data, ecowitt, setup_asyncio_mqtt, setup_uvicorn_server +): + """Test a successful MQTT publish.""" + async with ClientSession() as session: + resp = await session.request( + "post", + f"http://0.0.0.0:{TEST_PORT}{TEST_ENDPOINT}", + data=device_data, + ) + assert resp.status == 204 diff --git a/tests/test_server.py b/tests/test_server.py deleted file mode 100644 index b82eec0a..00000000 --- a/tests/test_server.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Define tests for the API server.""" -from __future__ import annotations - -import logging -from unittest.mock import AsyncMock - -from aiohttp import ClientSession -from asyncio_mqtt import MqttError -import pytest - -from ecowitt2mqtt.const import CONF_DIAGNOSTICS - -from tests.common import TEST_CONFIG_JSON, TEST_ENDPOINT, TEST_PORT, async_run_server - - -@pytest.mark.asyncio -@pytest.mark.parametrize("config", [{**TEST_CONFIG_JSON, CONF_DIAGNOSTICS: True}]) -async def test_get_diagnostics(caplog, device_data, ecowitt, setup_asyncio_mqtt): - """Test getting diagnostics.""" - caplog.set_level(logging.DEBUG) - async with async_run_server(ecowitt): - async with ClientSession() as session: - resp = await session.request( - "post", - f"http://0.0.0.0:{TEST_PORT}{TEST_ENDPOINT}", - data=device_data, - ) - assert resp.status == 204 - assert any(m for m in caplog.messages if "DIAGNOSTICS COLLECTED" in m) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("mqtt_publish_side_effect", [AsyncMock(side_effect=MqttError)]) -async def test_publish_failure(caplog, device_data, ecowitt, setup_asyncio_mqtt): - """Test a failed MQTT publish.""" - async with async_run_server(ecowitt): - async with ClientSession() as session: - await session.request( - "post", - f"http://0.0.0.0:{TEST_PORT}{TEST_ENDPOINT}", - data=device_data, - ) - assert any(m for m in caplog.messages if "Unable to publish payload" in m) - - -@pytest.mark.asyncio -async def test_publish_success(device_data, ecowitt, setup_asyncio_mqtt): - """Test a successful MQTT publish.""" - async with async_run_server(ecowitt): - async with ClientSession() as session: - resp = await session.request( - "post", - f"http://0.0.0.0:{TEST_PORT}{TEST_ENDPOINT}", - data=device_data, - ) - assert resp.status == 204