Skip to content

Entity recomputation and add/remove at runtime #517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: dev
Choose a base branch
from
61 changes: 61 additions & 0 deletions tests/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from zigpy.zcl import ClusterType
from zigpy.zcl.clusters import general
from zigpy.zcl.clusters.general import Ota, PowerConfiguration
from zigpy.zcl.clusters.lighting import Color
from zigpy.zcl.foundation import Status, WriteAttributesResponse
import zigpy.zdo.types as zdo_t

Expand Down Expand Up @@ -49,6 +50,8 @@
from zha.exceptions import ZHAException
from zha.zigbee.device import (
ClusterBinding,
DeviceEntityAddedEvent,
DeviceEntityRemovedEvent,
DeviceFirmwareInfoUpdatedEvent,
ZHAEvent,
get_device_automation_triggers,
Expand Down Expand Up @@ -1201,3 +1204,61 @@ async def test_symfonisk_events(
)
)
]


async def test_entity_recomputation(zha_gateway: Gateway) -> None:
"""Test entity recomputation."""
zigpy_dev = await zigpy_device_from_json(
zha_gateway.application_controller,
"tests/data/devices/ikea-of-sweden-tradfri-bulb-gu10-ws-400lm.json",
)
zha_device = await join_zigpy_device(zha_gateway, zigpy_dev)

event_listener = mock.Mock()
zha_device.on_all_events(event_listener)

entities1 = set(zha_device.platform_entities.values())

# We lose track of the color temperature
zha_device._zigpy_device.endpoints[1].light_color.add_unsupported_attribute(
Color.AttributeDefs.start_up_color_temperature.id
)
await zha_device.recompute_entities()

entities2 = set(zha_device.platform_entities.values())
assert entities2 - entities1 == set()
assert len(entities1 - entities2) == 1
assert (
list(entities1 - entities2)[0].unique_id
== "68:0a:e2:ff:fe:8f:fa:33-1-768-start_up_color_temperature"
)
assert event_listener.mock_calls == [
call(
DeviceEntityRemovedEvent(
unique_id="68:0a:e2:ff:fe:8f:fa:33-1-768-start_up_color_temperature"
)
)
]

event_listener.reset_mock()

# We add it back
zha_device._zigpy_device.endpoints[1].light_color.remove_unsupported_attribute(
Color.AttributeDefs.start_up_color_temperature.id
)
await zha_device.recompute_entities()

entities3 = set(zha_device.platform_entities.values())
assert (
list(entities3 - entities2)[0].unique_id
== "68:0a:e2:ff:fe:8f:fa:33-1-768-start_up_color_temperature"
)
assert {e.unique_id for e in entities1} == {e.unique_id for e in entities3}

assert event_listener.mock_calls == [
call(
DeviceEntityAddedEvent(
unique_id="68:0a:e2:ff:fe:8f:fa:33-1-768-start_up_color_temperature"
)
)
]
7 changes: 5 additions & 2 deletions tests/test_discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,9 @@ async def test_devices_from_files(
await zha_gateway.async_block_till_done(wait_background_tasks=True)
assert zha_device is not None

# Ensure entity recomputation is idempotent
await zha_device.recompute_entities()

unique_id_collisions = defaultdict(list)
for entity in zha_device.platform_entities.values():
unique_id_collisions[entity.unique_id].append(entity)
Expand Down Expand Up @@ -841,8 +844,6 @@ async def test_devices_from_files(

unique_id_migrations[key] = entity

await zha_device.on_remove()

# XXX: We re-serialize the JSON because integer enum types are converted when
# serializing but will not compare properly otherwise
loaded_device_data = json.loads(
Expand Down Expand Up @@ -871,3 +872,5 @@ async def test_devices_from_files(
tsn=None,
)
]

await zha_device.on_remove()
2 changes: 2 additions & 0 deletions zha/application/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ def pretty_name(self) -> str:
ZHA_CLUSTER_HANDLER_READS_PER_REQ = 5
ZHA_EVENT = "zha_event"
ZHA_DEVICE_UPDATED_EVENT = "zha_device_updated_event"
ZHA_DEVICE_ENTITY_ADDED_EVENT = "zha_device_entity_added_event"
ZHA_DEVICE_ENTITY_REMOVED_EVENT = "zha_device_entity_removed_event"
ZHA_GW_MSG = "zha_gateway_message"
ZHA_GW_MSG_DEVICE_FULL_INIT = "device_fully_initialized"
ZHA_GW_MSG_DEVICE_INFO = "device_info"
Expand Down
2 changes: 1 addition & 1 deletion zha/application/platforms/button/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class IdentifyButton(Button):
def is_supported_in_list(self, entities: list[BaseEntity]) -> bool:
"""Check if this button is supported given the list of entities."""
cls = type(self)
return not any(type(entity) is cls for entity in entities)
return not any(type(entity) is cls for entity in entities if entity is not self)


class WriteAttributeButton(PlatformEntity):
Expand Down
2 changes: 1 addition & 1 deletion zha/application/platforms/sensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1641,7 +1641,7 @@ def _is_supported(self) -> bool:
def is_supported_in_list(self, entities: list[BaseEntity]) -> bool:
"""Check if the sensor is supported given the list of entities."""
cls = type(self)
return not any(type(entity) is cls for entity in entities)
return not any(type(entity) is cls for entity in entities if entity is not self)

@property
def state(self) -> dict:
Expand Down
155 changes: 123 additions & 32 deletions zha/zigbee/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import asyncio
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator
from collections.abc import Callable, Iterable, Iterator, Sequence
import copy
import dataclasses
from dataclasses import dataclass
Expand Down Expand Up @@ -61,6 +61,8 @@
UNKNOWN_MODEL,
ZHA_CLUSTER_HANDLER_CFG_DONE,
ZHA_CLUSTER_HANDLER_MSG,
ZHA_DEVICE_ENTITY_ADDED_EVENT,
ZHA_DEVICE_ENTITY_REMOVED_EVENT,
ZHA_DEVICE_UPDATED_EVENT,
ZHA_EVENT,
)
Expand Down Expand Up @@ -166,6 +168,27 @@ class DeviceFirmwareInfoUpdatedEvent:
new_firmware_version: str | None


@dataclass(kw_only=True, frozen=True)
class DeviceEntityAddedEvent:
"""Event generated when a new entity is added to a device."""

event_type: Final[str] = ZHA_DEVICE_ENTITY_ADDED_EVENT
event: Final[str] = ZHA_DEVICE_ENTITY_ADDED_EVENT

# TODO: allow all entity information to be serialized and include it here
unique_id: str


@dataclass(kw_only=True, frozen=True)
class DeviceEntityRemovedEvent:
"""Event generated when a new entity is added to a device."""

event_type: Final[str] = ZHA_DEVICE_ENTITY_REMOVED_EVENT
event: Final[str] = ZHA_DEVICE_ENTITY_REMOVED_EVENT

unique_id: str


@dataclass(kw_only=True, frozen=True)
class ClusterHandlerConfigurationComplete:
"""Event generated when all cluster handlers are configured."""
Expand Down Expand Up @@ -946,53 +969,121 @@ def _discover_new_entities(self) -> None:
entity.on_add()
self._pending_entities.append(entity)

async def async_initialize(self, from_cache: bool = False) -> None:
"""Initialize cluster handlers."""
self.debug("started initialization")
def _add_entity(self, entity: PlatformEntity) -> None:
"""Add an entity to the device."""
key = (entity.PLATFORM, entity.unique_id)

self._discover_new_entities()
if key in self._platform_entities:
raise ValueError(
f"Cannot add entity {entity!r}, unique ID already taken by {self._platform_entities[key]!r}"
)

await self._zdo_handler.async_initialize(from_cache)
self._zdo_handler.debug("'async_initialize' stage succeeded")
_LOGGER.debug("Discovered new entity %s", entity)

# We intentionally do not use `gather` here! This is so that if, for example,
# three `device.async_initialize()`s are spawned, only three concurrent requests
# will ever be in flight at once. Startup concurrency is managed at the device
# level.
for endpoint in self._endpoints.values():
try:
await endpoint.async_initialize(from_cache)
except Exception: # pylint: disable=broad-exception-caught
self.debug("Failed to initialize endpoint", exc_info=True)
# `entity.on_add()` is assumed to have been called already
self._platform_entities[key] = entity
self.emit(
DeviceEntityAddedEvent.event_type,
DeviceEntityAddedEvent(
unique_id=entity.unique_id,
),
)

async def _remove_entity(
self, entity: BaseEntity, *, emit_event: bool = True
) -> None:
"""Remove an entity from the device."""
key = (entity.PLATFORM, entity.unique_id)

if key not in self._platform_entities:
raise ValueError(f"Cannot remove entity {entity!r}, unique ID not found")

# Compute the final entities
await entity.on_remove()
del self._platform_entities[key]

if emit_event:
self.emit(
DeviceEntityRemovedEvent.event_type,
DeviceEntityRemovedEvent(
unique_id=entity.unique_id,
),
)

async def _add_pending_entities(self) -> None:
"""Add pending entities to the device."""
all_entities = dict(self._platform_entities)
new_entities: dict[tuple[Platform, str], PlatformEntity] = {}

for entity in self._pending_entities:
entity.recompute_capabilities()

# Ignore unsupported entities
if not entity.is_supported() or not entity.is_supported_in_list(
new_entities.values()
all_entities.values()
):
await entity.on_remove()
continue

key = (entity.PLATFORM, entity.unique_id)

# Ignore entities that already exist
if key in new_entities:
if key in all_entities:
await entity.on_remove()
continue

all_entities[key] = entity
new_entities[key] = entity

if new_entities:
_LOGGER.debug("Discovered new entities %r", new_entities)
self._platform_entities.update(new_entities)
self._pending_entities.clear()

# Compute a new primary entity
self._compute_primary_entity(all_entities.values())

# Finally, add the new entities
for entity in new_entities.values():
self._add_entity(entity)

async def recompute_entities(self) -> None:
"""Recompute all entities for this device."""
self.debug("Recomputing entities")

entities = list(self._platform_entities.values())

# Remove all entities that are no longer supported
for entity in entities[:]:
entity.recompute_capabilities()

if not entity.is_supported() or not entity.is_supported_in_list(entities):
self.debug("Removing unsupported entity %s", entity)
await self._remove_entity(entity)
entities.remove(entity)

# Discover new entities
self._discover_new_entities()
await self._add_pending_entities()

async def async_initialize(self, from_cache: bool = False) -> None:
"""Initialize cluster handlers."""
self.debug("started initialization")

# We discover prospective entities before initialization
self._discover_new_entities()

# At this point we can compute a primary entity
self._compute_primary_entity()
await self._zdo_handler.async_initialize(from_cache)
self._zdo_handler.debug("'async_initialize' stage succeeded")

# We intentionally do not use `gather` here! This is so that if, for example,
# three `device.async_initialize()`s are spawned, only three concurrent requests
# will ever be in flight at once. Startup concurrency is managed at the device
# level.
for endpoint in self._endpoints.values():
try:
await endpoint.async_initialize(from_cache)
except Exception: # pylint: disable=broad-exception-caught
self.debug("Failed to initialize endpoint", exc_info=True)

# And add them after
await self._add_pending_entities()

# Sync the device's firmware version with the first platform entity
for (platform, _unique_id), entity in self.platform_entities.items():
Expand Down Expand Up @@ -1023,8 +1114,10 @@ async def on_remove(self) -> None:
for callback in self._on_remove_callbacks:
callback()

for platform_entity in self._platform_entities.values():
await platform_entity.on_remove()
for platform_entity in list(self._platform_entities.values()):
# TODO: To avoid unnecessary traffic during shutdown, we don't need to emit
# an event for every entity, just the device
await self._remove_entity(platform_entity, emit_event=False)

for entity in self._pending_entities:
await entity.on_remove()
Expand Down Expand Up @@ -1329,13 +1422,11 @@ def log(self, level: int, msg: str, *args: Any, **kwargs: Any) -> None:
args = (self.nwk, self.model) + args
_LOGGER.log(level, msg, *args, **kwargs)

def _compute_primary_entity(self) -> None:
"""Compute the primary entity for this device."""
def _compute_primary_entity(self, entities: Sequence[PlatformEntity]) -> None:
"""Compute the primary entity from a given set of entities."""

# First, check if any entity is explicitly primary
explicitly_primary = [
entity for entity in self._platform_entities.values() if entity.primary
]
explicitly_primary = [entity for entity in entities if entity.primary]

if len(explicitly_primary) == 1:
self.debug(
Expand All @@ -1351,7 +1442,7 @@ def _compute_primary_entity(self) -> None:
# not explicitly marked as not primary
candidates = [
e
for e in self._platform_entities.values()
for e in entities
if e.enabled and hasattr(e, "info_object") and e._attr_primary is not False
]
candidates.sort(reverse=True, key=lambda e: e.primary_weight)
Expand Down
Loading