diff --git a/tests/test_device.py b/tests/test_device.py index 3670c1f3c..f67fe2bd5 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -18,6 +18,7 @@ from zigpy.zcl import ClusterType from zigpy.zcl.clusters import general from zigpy.zcl.clusters.general import Ota, PowerConfiguration +from zigpy.zcl.clusters.lighting import Color from zigpy.zcl.foundation import Status, WriteAttributesResponse import zigpy.zdo.types as zdo_t @@ -49,6 +50,8 @@ from zha.exceptions import ZHAException from zha.zigbee.device import ( ClusterBinding, + DeviceEntityAddedEvent, + DeviceEntityRemovedEvent, DeviceFirmwareInfoUpdatedEvent, ZHAEvent, get_device_automation_triggers, @@ -1201,3 +1204,61 @@ async def test_symfonisk_events( ) ) ] + + +async def test_entity_recomputation(zha_gateway: Gateway) -> None: + """Test entity recomputation.""" + zigpy_dev = await zigpy_device_from_json( + zha_gateway.application_controller, + "tests/data/devices/ikea-of-sweden-tradfri-bulb-gu10-ws-400lm.json", + ) + zha_device = await join_zigpy_device(zha_gateway, zigpy_dev) + + event_listener = mock.Mock() + zha_device.on_all_events(event_listener) + + entities1 = set(zha_device.platform_entities.values()) + + # We lose track of the color temperature + zha_device._zigpy_device.endpoints[1].light_color.add_unsupported_attribute( + Color.AttributeDefs.start_up_color_temperature.id + ) + await zha_device.recompute_entities() + + entities2 = set(zha_device.platform_entities.values()) + assert entities2 - entities1 == set() + assert len(entities1 - entities2) == 1 + assert ( + list(entities1 - entities2)[0].unique_id + == "68:0a:e2:ff:fe:8f:fa:33-1-768-start_up_color_temperature" + ) + assert event_listener.mock_calls == [ + call( + DeviceEntityRemovedEvent( + unique_id="68:0a:e2:ff:fe:8f:fa:33-1-768-start_up_color_temperature" + ) + ) + ] + + event_listener.reset_mock() + + # We add it back + zha_device._zigpy_device.endpoints[1].light_color.remove_unsupported_attribute( + Color.AttributeDefs.start_up_color_temperature.id + ) + await zha_device.recompute_entities() + + entities3 = set(zha_device.platform_entities.values()) + assert ( + list(entities3 - entities2)[0].unique_id + == "68:0a:e2:ff:fe:8f:fa:33-1-768-start_up_color_temperature" + ) + assert {e.unique_id for e in entities1} == {e.unique_id for e in entities3} + + assert event_listener.mock_calls == [ + call( + DeviceEntityAddedEvent( + unique_id="68:0a:e2:ff:fe:8f:fa:33-1-768-start_up_color_temperature" + ) + ) + ] diff --git a/tests/test_discover.py b/tests/test_discover.py index ea8e9fd57..30a85a7d0 100644 --- a/tests/test_discover.py +++ b/tests/test_discover.py @@ -807,6 +807,9 @@ async def test_devices_from_files( await zha_gateway.async_block_till_done(wait_background_tasks=True) assert zha_device is not None + # Ensure entity recomputation is idempotent + await zha_device.recompute_entities() + unique_id_collisions = defaultdict(list) for entity in zha_device.platform_entities.values(): unique_id_collisions[entity.unique_id].append(entity) @@ -841,8 +844,6 @@ async def test_devices_from_files( unique_id_migrations[key] = entity - await zha_device.on_remove() - # XXX: We re-serialize the JSON because integer enum types are converted when # serializing but will not compare properly otherwise loaded_device_data = json.loads( @@ -871,3 +872,5 @@ async def test_devices_from_files( tsn=None, ) ] + + await zha_device.on_remove() diff --git a/zha/application/const.py b/zha/application/const.py index 08c857afb..8ac116fc2 100644 --- a/zha/application/const.py +++ b/zha/application/const.py @@ -193,6 +193,8 @@ def pretty_name(self) -> str: ZHA_CLUSTER_HANDLER_READS_PER_REQ = 5 ZHA_EVENT = "zha_event" ZHA_DEVICE_UPDATED_EVENT = "zha_device_updated_event" +ZHA_DEVICE_ENTITY_ADDED_EVENT = "zha_device_entity_added_event" +ZHA_DEVICE_ENTITY_REMOVED_EVENT = "zha_device_entity_removed_event" ZHA_GW_MSG = "zha_gateway_message" ZHA_GW_MSG_DEVICE_FULL_INIT = "device_fully_initialized" ZHA_GW_MSG_DEVICE_INFO = "device_info" diff --git a/zha/application/platforms/button/__init__.py b/zha/application/platforms/button/__init__.py index ed8293f21..d7ed4d565 100644 --- a/zha/application/platforms/button/__init__.py +++ b/zha/application/platforms/button/__init__.py @@ -124,7 +124,7 @@ class IdentifyButton(Button): def is_supported_in_list(self, entities: list[BaseEntity]) -> bool: """Check if this button is supported given the list of entities.""" cls = type(self) - return not any(type(entity) is cls for entity in entities) + return not any(type(entity) is cls for entity in entities if entity is not self) class WriteAttributeButton(PlatformEntity): diff --git a/zha/application/platforms/sensor/__init__.py b/zha/application/platforms/sensor/__init__.py index fcfa8558e..730314d33 100644 --- a/zha/application/platforms/sensor/__init__.py +++ b/zha/application/platforms/sensor/__init__.py @@ -1641,7 +1641,7 @@ def _is_supported(self) -> bool: def is_supported_in_list(self, entities: list[BaseEntity]) -> bool: """Check if the sensor is supported given the list of entities.""" cls = type(self) - return not any(type(entity) is cls for entity in entities) + return not any(type(entity) is cls for entity in entities if entity is not self) @property def state(self) -> dict: diff --git a/zha/zigbee/device.py b/zha/zigbee/device.py index af0202a2d..df19e8333 100644 --- a/zha/zigbee/device.py +++ b/zha/zigbee/device.py @@ -6,7 +6,7 @@ import asyncio from collections import defaultdict -from collections.abc import Callable, Iterable, Iterator +from collections.abc import Callable, Iterable, Iterator, Sequence import copy import dataclasses from dataclasses import dataclass @@ -61,6 +61,8 @@ UNKNOWN_MODEL, ZHA_CLUSTER_HANDLER_CFG_DONE, ZHA_CLUSTER_HANDLER_MSG, + ZHA_DEVICE_ENTITY_ADDED_EVENT, + ZHA_DEVICE_ENTITY_REMOVED_EVENT, ZHA_DEVICE_UPDATED_EVENT, ZHA_EVENT, ) @@ -166,6 +168,27 @@ class DeviceFirmwareInfoUpdatedEvent: new_firmware_version: str | None +@dataclass(kw_only=True, frozen=True) +class DeviceEntityAddedEvent: + """Event generated when a new entity is added to a device.""" + + event_type: Final[str] = ZHA_DEVICE_ENTITY_ADDED_EVENT + event: Final[str] = ZHA_DEVICE_ENTITY_ADDED_EVENT + + # TODO: allow all entity information to be serialized and include it here + unique_id: str + + +@dataclass(kw_only=True, frozen=True) +class DeviceEntityRemovedEvent: + """Event generated when a new entity is added to a device.""" + + event_type: Final[str] = ZHA_DEVICE_ENTITY_REMOVED_EVENT + event: Final[str] = ZHA_DEVICE_ENTITY_REMOVED_EVENT + + unique_id: str + + @dataclass(kw_only=True, frozen=True) class ClusterHandlerConfigurationComplete: """Event generated when all cluster handlers are configured.""" @@ -946,26 +969,49 @@ def _discover_new_entities(self) -> None: entity.on_add() self._pending_entities.append(entity) - async def async_initialize(self, from_cache: bool = False) -> None: - """Initialize cluster handlers.""" - self.debug("started initialization") + def _add_entity(self, entity: PlatformEntity) -> None: + """Add an entity to the device.""" + key = (entity.PLATFORM, entity.unique_id) - self._discover_new_entities() + if key in self._platform_entities: + raise ValueError( + f"Cannot add entity {entity!r}, unique ID already taken by {self._platform_entities[key]!r}" + ) - await self._zdo_handler.async_initialize(from_cache) - self._zdo_handler.debug("'async_initialize' stage succeeded") + _LOGGER.debug("Discovered new entity %s", entity) - # We intentionally do not use `gather` here! This is so that if, for example, - # three `device.async_initialize()`s are spawned, only three concurrent requests - # will ever be in flight at once. Startup concurrency is managed at the device - # level. - for endpoint in self._endpoints.values(): - try: - await endpoint.async_initialize(from_cache) - except Exception: # pylint: disable=broad-exception-caught - self.debug("Failed to initialize endpoint", exc_info=True) + # `entity.on_add()` is assumed to have been called already + self._platform_entities[key] = entity + self.emit( + DeviceEntityAddedEvent.event_type, + DeviceEntityAddedEvent( + unique_id=entity.unique_id, + ), + ) + + async def _remove_entity( + self, entity: BaseEntity, *, emit_event: bool = True + ) -> None: + """Remove an entity from the device.""" + key = (entity.PLATFORM, entity.unique_id) + + if key not in self._platform_entities: + raise ValueError(f"Cannot remove entity {entity!r}, unique ID not found") - # Compute the final entities + await entity.on_remove() + del self._platform_entities[key] + + if emit_event: + self.emit( + DeviceEntityRemovedEvent.event_type, + DeviceEntityRemovedEvent( + unique_id=entity.unique_id, + ), + ) + + async def _add_pending_entities(self) -> None: + """Add pending entities to the device.""" + all_entities = dict(self._platform_entities) new_entities: dict[tuple[Platform, str], PlatformEntity] = {} for entity in self._pending_entities: @@ -973,7 +1019,7 @@ async def async_initialize(self, from_cache: bool = False) -> None: # Ignore unsupported entities if not entity.is_supported() or not entity.is_supported_in_list( - new_entities.values() + all_entities.values() ): await entity.on_remove() continue @@ -981,18 +1027,63 @@ async def async_initialize(self, from_cache: bool = False) -> None: key = (entity.PLATFORM, entity.unique_id) # Ignore entities that already exist - if key in new_entities: + if key in all_entities: await entity.on_remove() continue + all_entities[key] = entity new_entities[key] = entity - if new_entities: - _LOGGER.debug("Discovered new entities %r", new_entities) - self._platform_entities.update(new_entities) + self._pending_entities.clear() + + # Compute a new primary entity + self._compute_primary_entity(all_entities.values()) + + # Finally, add the new entities + for entity in new_entities.values(): + self._add_entity(entity) + + async def recompute_entities(self) -> None: + """Recompute all entities for this device.""" + self.debug("Recomputing entities") + + entities = list(self._platform_entities.values()) + + # Remove all entities that are no longer supported + for entity in entities[:]: + entity.recompute_capabilities() + + if not entity.is_supported() or not entity.is_supported_in_list(entities): + self.debug("Removing unsupported entity %s", entity) + await self._remove_entity(entity) + entities.remove(entity) + + # Discover new entities + self._discover_new_entities() + await self._add_pending_entities() + + async def async_initialize(self, from_cache: bool = False) -> None: + """Initialize cluster handlers.""" + self.debug("started initialization") + + # We discover prospective entities before initialization + self._discover_new_entities() - # At this point we can compute a primary entity - self._compute_primary_entity() + await self._zdo_handler.async_initialize(from_cache) + self._zdo_handler.debug("'async_initialize' stage succeeded") + + # We intentionally do not use `gather` here! This is so that if, for example, + # three `device.async_initialize()`s are spawned, only three concurrent requests + # will ever be in flight at once. Startup concurrency is managed at the device + # level. + for endpoint in self._endpoints.values(): + try: + await endpoint.async_initialize(from_cache) + except Exception: # pylint: disable=broad-exception-caught + self.debug("Failed to initialize endpoint", exc_info=True) + + # And add them after + await self._add_pending_entities() # Sync the device's firmware version with the first platform entity for (platform, _unique_id), entity in self.platform_entities.items(): @@ -1023,8 +1114,10 @@ async def on_remove(self) -> None: for callback in self._on_remove_callbacks: callback() - for platform_entity in self._platform_entities.values(): - await platform_entity.on_remove() + for platform_entity in list(self._platform_entities.values()): + # TODO: To avoid unnecessary traffic during shutdown, we don't need to emit + # an event for every entity, just the device + await self._remove_entity(platform_entity, emit_event=False) for entity in self._pending_entities: await entity.on_remove() @@ -1329,13 +1422,11 @@ def log(self, level: int, msg: str, *args: Any, **kwargs: Any) -> None: args = (self.nwk, self.model) + args _LOGGER.log(level, msg, *args, **kwargs) - def _compute_primary_entity(self) -> None: - """Compute the primary entity for this device.""" + def _compute_primary_entity(self, entities: Sequence[PlatformEntity]) -> None: + """Compute the primary entity from a given set of entities.""" # First, check if any entity is explicitly primary - explicitly_primary = [ - entity for entity in self._platform_entities.values() if entity.primary - ] + explicitly_primary = [entity for entity in entities if entity.primary] if len(explicitly_primary) == 1: self.debug( @@ -1351,7 +1442,7 @@ def _compute_primary_entity(self) -> None: # not explicitly marked as not primary candidates = [ e - for e in self._platform_entities.values() + for e in entities if e.enabled and hasattr(e, "info_object") and e._attr_primary is not False ] candidates.sort(reverse=True, key=lambda e: e.primary_weight)