Skip to content

Commit

Permalink
Allow re-discovery of mqtt integration config payloads (home-assistan…
Browse files Browse the repository at this point in the history
  • Loading branch information
jbouwh authored Oct 26, 2024
1 parent d8b618f commit d237180
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 35 deletions.
63 changes: 56 additions & 7 deletions homeassistant/components/mqtt/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,22 @@

import asyncio
from collections import deque
from dataclasses import dataclass
import functools
from itertools import chain
import logging
import re
import time
from typing import TYPE_CHECKING, Any

from homeassistant.config_entries import ConfigEntry
from homeassistant.config_entries import (
SOURCE_MQTT,
ConfigEntry,
signal_discovered_config_entry_removed,
)
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
from homeassistant.core import HassJobType, HomeAssistant, callback
from homeassistant.helpers import discovery_flow
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
Expand Down Expand Up @@ -71,6 +77,14 @@ class MQTTDiscoveryPayload(dict[str, Any]):
discovery_data: DiscoveryInfoType


@dataclass(frozen=True)
class MQTTIntegrationDiscoveryConfig:
"""Class to hold an integration discovery playload."""

integration: str
msg: ReceiveMessage


def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None:
"""Clear entry from already discovered list."""
hass.data[DATA_MQTT].discovery_already_discovered.discard(discovery_hash)
Expand Down Expand Up @@ -191,7 +205,7 @@ async def async_start( # noqa: C901
"""Start MQTT Discovery."""
mqtt_data = hass.data[DATA_MQTT]
platform_setup_lock: dict[str, asyncio.Lock] = {}
integration_discovery_messages: dict[str, int] = {}
integration_discovery_messages: dict[str, MQTTIntegrationDiscoveryConfig] = {}

@callback
def _async_add_component(discovery_payload: MQTTDiscoveryPayload) -> None:
Expand Down Expand Up @@ -364,13 +378,39 @@ def discovery_done(_: Any) -> None:
mqtt_integrations = await async_get_mqtt(hass)
integration_unsubscribe = mqtt_data.integration_unsubscribe

async def _async_handle_config_entry_removed(entry: ConfigEntry) -> None:
"""Handle integration config entry changes."""
for discovery_key in entry.discovery_keys[DOMAIN]:
if (
discovery_key.version != 1
or not isinstance(discovery_key.key, str)
or discovery_key.key not in integration_discovery_messages
):
continue
topic = discovery_key.key
discovery_message = integration_discovery_messages[topic]
del integration_discovery_messages[topic]
_LOGGER.debug("Rediscover service on topic %s", topic)
# Initiate re-discovery
await async_integration_message_received(
discovery_message.integration, discovery_message.msg
)

mqtt_data.discovery_unsubscribe.append(
async_dispatcher_connect(
hass,
signal_discovered_config_entry_removed(DOMAIN),
_async_handle_config_entry_removed,
)
)

async def async_integration_message_received(
integration: str, msg: ReceiveMessage
) -> None:
"""Process the received message."""
if (
msg.topic in integration_discovery_messages
and integration_discovery_messages[msg.topic] == hash(msg.payload)
and integration_discovery_messages[msg.topic].msg.payload == msg.payload
):
_LOGGER.debug(
"Ignoring already processed discovery message for '%s' on topic %s: %s",
Expand All @@ -393,14 +433,23 @@ async def async_integration_message_received(
subscribed_topic=msg.subscribed_topic,
timestamp=msg.timestamp,
)
await hass.config_entries.flow.async_init(
integration, context={"source": DOMAIN}, data=data
discovery_key = discovery_flow.DiscoveryKey(
domain=DOMAIN, key=msg.topic, version=1
)
discovery_flow.async_create_flow(
hass,
integration,
{"source": SOURCE_MQTT},
data,
discovery_key=discovery_key,
)
if msg.payload:
# Update the last discovered config message
integration_discovery_messages[msg.topic] = hash(msg.payload)
integration_discovery_messages[msg.topic] = (
MQTTIntegrationDiscoveryConfig(integration=integration, msg=msg)
)
elif msg.topic in integration_discovery_messages:
# Cleanup hash if discovery payload is empty
# Cleanup cache if discovery payload is empty
del integration_discovery_messages[msg.topic]

integration_unsubscribe.update(
Expand Down
138 changes: 110 additions & 28 deletions tests/components/mqtt/test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
Platform,
)
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.data_entry_flow import AbortFlow, FlowResult
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
Expand Down Expand Up @@ -63,6 +63,53 @@
)


@pytest.fixture
def mqtt_data_flow_calls() -> list[MqttServiceInfo]:
"""Return list to capture MQTT data data flow calls."""
return []


@pytest.fixture
async def mock_mqtt_flow(
hass: HomeAssistant, mqtt_data_flow_calls: list[MqttServiceInfo]
) -> config_entries.ConfigFlow:
"""Test fixure for mqtt integration flow.
The topic is used as a unique ID.
The component test domain used is: `comp`.
Creates an entry if does not exist.
Updates an entry if it exists, and there is an updated payload.
"""

class TestFlow(config_entries.ConfigFlow):
"""Test flow."""

async def async_step_mqtt(self, discovery_info: MqttServiceInfo) -> FlowResult:
"""Test mqtt step."""
await asyncio.sleep(0)
mqtt_data_flow_calls.append(discovery_info)
# Abort a flow if there is an update for the existing entry
if entry := self.hass.config_entries.async_entry_for_domain_unique_id(
"comp", discovery_info.topic
):
hass.config_entries.async_update_entry(
entry,
data={
"name": discovery_info.topic,
"payload": discovery_info.payload,
},
)
raise AbortFlow("already_configured")
await self.async_set_unique_id(discovery_info.topic)
return self.async_create_entry(
title="Test",
data={"name": discovery_info.topic, "payload": discovery_info.payload},
)

return TestFlow


@pytest.mark.parametrize(
"mqtt_config_entry_data",
[{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_DISCOVERY: False}],
Expand Down Expand Up @@ -1518,20 +1565,14 @@ async def test_mqtt_discovery_flow_starts_once(
hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient,
caplog: pytest.LogCaptureFixture,
mock_mqtt_flow: config_entries.ConfigFlow,
mqtt_data_flow_calls: list[MqttServiceInfo],
) -> None:
"""Check MQTT integration discovery starts a flow once."""

flow_calls: list[MqttServiceInfo] = []

class TestFlow(config_entries.ConfigFlow):
"""Test flow."""

async def async_step_mqtt(self, discovery_info: MqttServiceInfo) -> FlowResult:
"""Test mqtt step."""
await asyncio.sleep(0)
flow_calls.append(discovery_info)
return self.async_create_entry(title="Test", data={})
"""Check MQTT integration discovery starts a flow once.
A flow should be started once after discovery,
and after an entry was removed, to trigger re-discovery.
"""
mock_integration(
hass, MockModule(domain="comp", async_setup_entry=AsyncMock(return_value=True))
)
Expand All @@ -1552,7 +1593,7 @@ def wait_birth(msg: ReceiveMessage) -> None:
"homeassistant.components.mqtt.discovery.async_get_mqtt",
return_value={"comp": ["comp/discovery/#"]},
),
mock_config_flow("comp", TestFlow),
mock_config_flow("comp", mock_mqtt_flow),
):
assert await hass.config_entries.async_setup(entry.entry_id)
await mqtt.async_subscribe(hass, "homeassistant/status", wait_birth)
Expand All @@ -1561,41 +1602,82 @@ def wait_birth(msg: ReceiveMessage) -> None:

assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock)

# Test the initial flow
async_fire_mqtt_message(hass, "comp/discovery/bla/config1", "initial message")
await hass.async_block_till_done(wait_background_tasks=True)
assert len(flow_calls) == 1
assert flow_calls[0].topic == "comp/discovery/bla/config1"
assert flow_calls[0].payload == "initial message"
assert len(mqtt_data_flow_calls) == 1
assert mqtt_data_flow_calls[0].topic == "comp/discovery/bla/config1"
assert mqtt_data_flow_calls[0].payload == "initial message"

# Test we can ignore updates if they are the same
with caplog.at_level(logging.DEBUG):
async_fire_mqtt_message(
hass, "comp/discovery/bla/config1", "initial message"
)
await hass.async_block_till_done(wait_background_tasks=True)
assert "Ignoring already processed discovery message" in caplog.text
assert len(flow_calls) == 1
assert len(mqtt_data_flow_calls) == 1

# Test we can apply updates
async_fire_mqtt_message(hass, "comp/discovery/bla/config1", "update message")
await hass.async_block_till_done(wait_background_tasks=True)

assert len(mqtt_data_flow_calls) == 2
assert mqtt_data_flow_calls[1].topic == "comp/discovery/bla/config1"
assert mqtt_data_flow_calls[1].payload == "update message"

# Test we set up multiple entries
async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "initial message")
await hass.async_block_till_done(wait_background_tasks=True)

assert len(flow_calls) == 2
assert flow_calls[1].topic == "comp/discovery/bla/config2"
assert flow_calls[1].payload == "initial message"
assert len(mqtt_data_flow_calls) == 3
assert mqtt_data_flow_calls[2].topic == "comp/discovery/bla/config2"
assert mqtt_data_flow_calls[2].payload == "initial message"

# Test we update multiple entries
async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "update message")
await hass.async_block_till_done(wait_background_tasks=True)

assert len(flow_calls) == 3
assert flow_calls[2].topic == "comp/discovery/bla/config2"
assert flow_calls[2].payload == "update message"
assert len(mqtt_data_flow_calls) == 4
assert mqtt_data_flow_calls[3].topic == "comp/discovery/bla/config2"
assert mqtt_data_flow_calls[3].payload == "update message"

# An empty message triggers a flow to allow cleanup
# Test an empty message triggers a flow to allow cleanup (if needed)
async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "")
await hass.async_block_till_done(wait_background_tasks=True)

assert len(flow_calls) == 4
assert flow_calls[3].topic == "comp/discovery/bla/config2"
assert flow_calls[3].payload == ""
assert len(mqtt_data_flow_calls) == 5
assert mqtt_data_flow_calls[4].topic == "comp/discovery/bla/config2"
assert mqtt_data_flow_calls[4].payload == ""

# Cleanup the the second entry
assert (
entry := hass.config_entries.async_entry_for_domain_unique_id(
"comp", "comp/discovery/bla/config2"
)
) is not None
await hass.config_entries.async_remove(entry.entry_id)
assert len(hass.config_entries.async_entries(domain="comp")) == 1

# Remove remaining entry1 and assert this triggers an
# automatic re-discovery flow with latest config
assert (
entry := hass.config_entries.async_entry_for_domain_unique_id(
"comp", "comp/discovery/bla/config1"
)
) is not None
assert entry.unique_id == "comp/discovery/bla/config1"
await hass.config_entries.async_remove(entry.entry_id)
assert len(hass.config_entries.async_entries(domain="comp")) == 0

# Wait for re-discovery flow to complete
await hass.async_block_till_done(wait_background_tasks=True)
assert len(mqtt_data_flow_calls) == 6
assert mqtt_data_flow_calls[5].topic == "comp/discovery/bla/config1"
assert mqtt_data_flow_calls[5].payload == "update message"

# Re-discovery triggered the config flow
assert len(hass.config_entries.async_entries(domain="comp")) == 1

assert not mqtt_client_mock.unsubscribe.called

Expand Down

0 comments on commit d237180

Please sign in to comment.