diff --git a/zha/application/platforms/update.py b/zha/application/platforms/update.py index 86789f73..5650798f 100644 --- a/zha/application/platforms/update.py +++ b/zha/application/platforms/update.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from enum import IntFlag, StrEnum import functools +import itertools import logging from typing import TYPE_CHECKING, Any, Final, final @@ -110,7 +111,9 @@ def __init__( ] self._attr_installed_version: str | None = self._get_cluster_version() self._attr_latest_version = self._attr_installed_version - self._latest_firmware: OtaImageWithMetadata | None = None + self._compatible_images: OtaImagesResult = OtaImagesResult( + upgrades=(), downgrades=() + ) self.device.device.add_listener(self) self._ota_cluster_handler.on_event( @@ -205,16 +208,14 @@ def state_attributes(self) -> dict[str, Any] | None: } @final - async def async_install_with_progress( - self, version: str | None, backup: bool - ) -> None: + async def async_install_with_progress(self, version: str | None) -> None: """Install update and handle progress if needed. Handles setting the in_progress state in case the entity doesn't support it natively. """ try: - await self.async_install(version, backup) + await self.async_install(version=version) finally: # No matter what happens, we always stop progress in the end self._attr_in_progress = False @@ -244,37 +245,35 @@ def device_ota_image_query_result( ) -> None: """Handle ota update available signal from Zigpy.""" - _LOGGER.debug("Got OTA result: %s - %s", images_result, query_next_img_command) - current_version = query_next_img_command.current_file_version self._attr_installed_version = f"0x{current_version:08x}" - self._latest_firmware = None + self._compatible_images = images_result self._attr_latest_version = None self._attr_release_summary = None self._attr_release_notes = None + latest_firmware: OtaImageWithMetadata | None = None + if images_result.upgrades: # If there are upgrades, cache the image and indicate that we should upgrade - self._latest_firmware = images_result.upgrades[0] - self._attr_latest_version = f"0x{self._latest_firmware.version:08x}" + latest_firmware = images_result.upgrades[0] + self._attr_latest_version = f"0x{latest_firmware.version:08x}" + self._attr_release_summary = ( + latest_firmware.metadata.changelog + if latest_firmware.metadata.changelog + else None + ) + self._attr_release_notes = ( + latest_firmware.metadata.release_notes + if latest_firmware.metadata.release_notes + else None + ) elif images_result.downgrades: # If not, note the version of the most recent firmware - self._latest_firmware = None + latest_firmware = None self._attr_latest_version = f"0x{images_result.downgrades[0].version:08x}" - if ( - self._latest_firmware is not None - and self._latest_firmware.metadata.changelog - ): - self._attr_release_summary = self._latest_firmware.metadata.changelog - - if ( - self._latest_firmware is not None - and self._latest_firmware.metadata.release_notes - ): - self._attr_release_notes = self._latest_firmware.metadata.release_notes - self.maybe_emit_state_changed_event() def _update_progress(self, current: int, total: int, progress: float) -> None: @@ -286,11 +285,25 @@ def _update_progress(self, current: int, total: int, progress: float) -> None: self._attr_progress = int(progress) self.maybe_emit_state_changed_event() - async def async_install( - self, version: str | None, backup: bool, **kwargs: Any - ) -> None: + async def async_install(self, version: str | None) -> None: """Install an update.""" - assert self._latest_firmware is not None + + if version is None: + if not self._compatible_images.upgrades: + raise ZHAException("No firmware updates are available") + + firmware = self._compatible_images.upgrades[0] + else: + version = int(version, 16) + + for firmware in itertools.chain( + self._compatible_images.upgrades, + self._compatible_images.downgrades, + ): + if firmware.version == version: + break + else: + raise ZHAException(f"Version {version!r} is not available") self._attr_in_progress = True self._attr_progress = 0 @@ -298,7 +311,7 @@ async def async_install( try: result = await self.device.device.update_firmware( - image=self._latest_firmware, + image=firmware, progress_callback=self._update_progress, ) except Exception as ex: @@ -320,7 +333,6 @@ async def async_install( raise ZHAException(f"Update was not successful: {result}") # Clear the state - self._latest_firmware = None self._attr_in_progress = False self.maybe_emit_state_changed_event()