Skip to content

Commit

Permalink
Use single MQTT connection for each payload publish (#236)
Browse files Browse the repository at this point in the history
  • Loading branch information
bachya authored Jul 31, 2022
1 parent 70905e6 commit 2ec5c81
Show file tree
Hide file tree
Showing 13 changed files with 351 additions and 289 deletions.
16 changes: 10 additions & 6 deletions ecowitt2mqtt/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from typing import Any

from ecowitt2mqtt.config import Config
from ecowitt2mqtt.const import CONF_VERBOSE, LEGACY_ENV_LOG_LEVEL, LOGGER
from ecowitt2mqtt.const import CONF_VERBOSE, LEGACY_ENV_LOG_LEVEL
from ecowitt2mqtt.helpers.logging import TyperLoggerHandler
from ecowitt2mqtt.server import Server
from ecowitt2mqtt.runtime import Runtime


class Ecowitt: # pylint: disable=too-few-public-methods
Expand All @@ -27,10 +27,14 @@ def __init__(self, params: dict[str, Any]) -> None:
handlers=(TyperLoggerHandler(),),
)

self.config = Config(params)
self.server = Server(self)
self._config = Config(params)
self._runtime = Runtime(self)

@property
def config(self) -> Config:
"""Return the config object."""
return self._config

async def async_start(self) -> None:
"""Start ecowitt2mqtt."""
LOGGER.info("Starting ecowitt2mqtt")
await self.server.async_start()
await self._runtime.async_start()
26 changes: 1 addition & 25 deletions ecowitt2mqtt/helpers/publisher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,18 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from datetime import datetime
import json
from ssl import SSLContext
from typing import TYPE_CHECKING, Any

from asyncio_mqtt import Client

from ecowitt2mqtt.const import LOGGER
from ecowitt2mqtt.errors import EcowittError
from ecowitt2mqtt.helpers.typing import DataValueType

if TYPE_CHECKING:
from ecowitt2mqtt.core import Ecowitt


class PublishError(EcowittError):
"""Define an error related to a failed data publish."""

pass


def generate_mqtt_payload(data: DataValueType) -> bytes:
"""Generate a binary MQTT payload from input data."""
if isinstance(data, dict):
Expand All @@ -50,20 +39,7 @@ def __init__(self, ecowitt: Ecowitt) -> None:
"""Initialize."""
self.ecowitt = ecowitt

@asynccontextmanager
async def async_get_client(self) -> AsyncIterator[Client]:
"""Get an MQTT client."""
async with Client(
self.ecowitt.config.mqtt_broker,
logger=LOGGER,
password=self.ecowitt.config.mqtt_password,
port=self.ecowitt.config.mqtt_port,
tls_context=SSLContext() if self.ecowitt.config.mqtt_tls else None,
username=self.ecowitt.config.mqtt_username,
) as client:
yield client

@abstractmethod
async def async_publish(self, data: dict[str, Any]) -> None:
async def async_publish(self, client: Client, data: dict[str, Any]) -> None:
"""Publish the data."""
raise NotImplementedError()
62 changes: 28 additions & 34 deletions ecowitt2mqtt/helpers/publisher/hass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, TypedDict

from asyncio_mqtt import MqttError
from asyncio_mqtt import Client, MqttError

from ecowitt2mqtt.backports.enum import StrEnum
from ecowitt2mqtt.const import (
Expand Down Expand Up @@ -78,11 +78,7 @@
get_battery_strategy,
)
from ecowitt2mqtt.helpers.device import Device
from ecowitt2mqtt.helpers.publisher import (
MqttPublisher,
PublishError,
generate_mqtt_payload,
)
from ecowitt2mqtt.helpers.publisher import MqttPublisher, generate_mqtt_payload
from ecowitt2mqtt.helpers.typing import DataValueType

if TYPE_CHECKING:
Expand Down Expand Up @@ -470,44 +466,42 @@ def _generate_discovery_payload(

return payload

async def async_publish(self, data: dict[str, DataValueType]) -> None:
async def async_publish(
self, client: Client, data: dict[str, DataValueType]
) -> None:
"""Publish to MQTT."""
processed_data = ProcessedData(self.ecowitt, data)
tasks = []

try:
async with self.async_get_client() as client:
for payload_key, data_point in processed_data.output.items():
discovery_payload = self._generate_discovery_payload(
processed_data.device, payload_key, data_point
)

for topic, payload in (
(discovery_payload.topic, discovery_payload.payload),
(
discovery_payload.payload["availability_topic"],
get_availability_payload(data_point),
),
(discovery_payload.payload["state_topic"], data_point.value),
):
tasks.append(
asyncio.create_task(
client.publish(
topic,
payload=generate_mqtt_payload(payload),
retain=self.ecowitt.config.mqtt_retain,
)
for payload_key, data_point in processed_data.output.items():
discovery_payload = self._generate_discovery_payload(
processed_data.device, payload_key, data_point
)

for topic, payload in (
(discovery_payload.topic, discovery_payload.payload),
(
discovery_payload.payload["availability_topic"],
get_availability_payload(data_point),
),
(discovery_payload.payload["state_topic"], data_point.value),
):
tasks.append(
asyncio.create_task(
client.publish(
topic,
payload=generate_mqtt_payload(payload),
retain=self.ecowitt.config.mqtt_retain,
)
)
)

await asyncio.gather(*tasks)
except MqttError as err:
await asyncio.gather(*tasks)
except MqttError:
for task in tasks:
task.cancel()

raise PublishError(
f"Error while publishing to Home Assistant MQTT Discovery: {err}"
) from err
raise

LOGGER.info("Published to Home Assistant MQTT Discovery")
LOGGER.debug("Published data: %s", processed_data.output)
28 changes: 10 additions & 18 deletions ecowitt2mqtt/helpers/publisher/topic.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,30 @@
"""Define MQTT publishing."""
from __future__ import annotations

from asyncio_mqtt import MqttError
from asyncio_mqtt import Client

from ecowitt2mqtt.const import LOGGER
from ecowitt2mqtt.data import ProcessedData
from ecowitt2mqtt.helpers.publisher import (
MqttPublisher,
PublishError,
generate_mqtt_payload,
)
from ecowitt2mqtt.helpers.publisher import MqttPublisher, generate_mqtt_payload
from ecowitt2mqtt.helpers.typing import DataValueType


class TopicPublisher(MqttPublisher):
"""Define an MQTT publisher that publishes to a topic."""

async def async_publish(self, data: dict[str, DataValueType]) -> None:
async def async_publish(
self, client: Client, data: dict[str, DataValueType]
) -> None:
"""Publish to MQTT."""
if not self.ecowitt.config.raw_data:
processed_data = ProcessedData(self.ecowitt, data)
data = {key: value.value for key, value in processed_data.output.items()}

try:
async with self.async_get_client() as client:
await client.publish(
self.ecowitt.config.mqtt_topic,
payload=generate_mqtt_payload(data),
retain=self.ecowitt.config.mqtt_retain,
)
except MqttError as err:
raise PublishError(
f"Error while publishing to {self.ecowitt.config.mqtt_topic}: {err}"
) from err
await client.publish(
self.ecowitt.config.mqtt_topic,
payload=generate_mqtt_payload(data),
retain=self.ecowitt.config.mqtt_retain,
)

LOGGER.info("Published to %s", self.ecowitt.config.mqtt_topic)
LOGGER.debug("Published data: %s", data)
168 changes: 168 additions & 0 deletions ecowitt2mqtt/runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
"""Define runtime management."""
from __future__ import annotations

import asyncio
import logging
import signal
from ssl import SSLContext
import traceback
from types import FrameType
from typing import TYPE_CHECKING, Any

from asyncio_mqtt import Client, MqttError
from fastapi import FastAPI, Request, Response, status
import uvicorn

from ecowitt2mqtt.const import LOGGER
from ecowitt2mqtt.helpers.publisher.factory import get_publisher

if TYPE_CHECKING:
from ecowitt2mqtt.core import Ecowitt

DEFAULT_HOST = "0.0.0.0"
DEFAULT_MAX_RETRY_INTERVAL = 60

LOG_LEVEL_DEBUG = "debug"
LOG_LEVEL_ERROR = "error"

HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
)


class MyCustomUvicornServer(uvicorn.Server): # type: ignore
"""Define a Uvicorn server that doesn't swallow signals."""

def install_signal_handlers(self) -> None:
"""Don't swallow signals."""
pass


class Runtime:
"""Define the runtime manager."""

def __init__(self, ecowitt: Ecowitt) -> None:
"""Initialize."""
self.ecowitt = ecowitt

app = FastAPI()
app.post(
ecowitt.config.endpoint,
status_code=status.HTTP_204_NO_CONTENT,
response_class=Response,
)(self._async_post_data)
self._server = MyCustomUvicornServer(
config=uvicorn.Config(
app,
host=DEFAULT_HOST,
port=ecowitt.config.port,
log_level="debug" if ecowitt.config.verbose else "info",
)
)

self._latest_payload: dict[str, Any] | None = None
self._new_payload_condition = asyncio.Condition()
self._publisher = get_publisher(ecowitt)
self._runtime_tasks: list[asyncio.Task] = []

# Remove the existing Uvicorn logger handler so that we don't get duplicates:
# https://github.com/encode/uvicorn/issues/1285
uvicorn_logger = logging.getLogger("uvicorn")
uvicorn_logger.removeHandler(uvicorn_logger.handlers[0])

async def _async_create_mqtt_loop(self) -> None:
"""Create the MQTT process loop."""
LOGGER.debug("Starting MQTT process loop")

retry_attempt = 0
while True:
try:
async with Client(
self.ecowitt.config.mqtt_broker,
logger=LOGGER,
password=self.ecowitt.config.mqtt_password,
port=self.ecowitt.config.mqtt_port,
tls_context=SSLContext() if self.ecowitt.config.mqtt_tls else None,
username=self.ecowitt.config.mqtt_username,
) as client:
while True:
async with self._new_payload_condition:
await self._new_payload_condition.wait()
LOGGER.debug("Publishing payload: %s", self._latest_payload)
assert self._latest_payload
await self._publisher.async_publish(
client, self._latest_payload
)
retry_attempt = 0

if self.ecowitt.config.diagnostics:
LOGGER.debug("*** DIAGNOSTICS COLLECTED")
self.stop()
except asyncio.CancelledError:
LOGGER.debug("Stopping MQTT process loop")
raise
except MqttError as err:
LOGGER.error("There was an MQTT error: %s", err)
LOGGER.debug("".join(traceback.format_tb(err.__traceback__)))

retry_attempt += 1
delay = min(retry_attempt**2, DEFAULT_MAX_RETRY_INTERVAL)
LOGGER.info(
"Attempting MQTT reconnection in %s seconds (attempt %s)",
delay,
retry_attempt,
)
await asyncio.sleep(delay)

async def _async_create_server(self) -> None:
"""Create the REST API server."""
LOGGER.debug("Starting REST API server")

try:
await self._server.serve()
except asyncio.CancelledError:
LOGGER.debug("Stopping REST API server")
raise

async def _async_post_data(self, request: Request) -> Response:
"""Define an endpoint for the Ecowitt device to post data to."""
payload = dict(await request.form())
LOGGER.debug("Received data from the Ecowitt device: %s", payload)
async with self._new_payload_condition:
self._latest_payload = payload
self._new_payload_condition.notify_all()

async def async_start(self) -> None:
"""Start the runtime."""
loop = asyncio.get_running_loop()

def handle_exit_signal(sig: int, frame: FrameType | None) -> None:
"""Handle an exit signal."""
if self._server.should_exit and sig == signal.SIGINT:
self._server.force_exit = True
else:
self._server.should_exit = True
self.stop()

try:
for sig in HANDLED_SIGNALS:
loop.add_signal_handler(sig, handle_exit_signal, sig, None)
except NotImplementedError:
# Windows
for sig in HANDLED_SIGNALS:
signal.signal(sig, handle_exit_signal)

for coro_func in self._async_create_mqtt_loop, self._async_create_server:
self._runtime_tasks.append(asyncio.create_task(coro_func()))

try:
await asyncio.gather(*self._runtime_tasks)
except asyncio.CancelledError:
await asyncio.sleep(0.1)
LOGGER.debug("Runtime shutdown complete")

def stop(self) -> None:
"""Stop the REST API server."""
for task in self._runtime_tasks:
task.cancel()
Loading

0 comments on commit 2ec5c81

Please sign in to comment.