Skip to content

Commit

Permalink
Always call shutdown on async_initialize failure (#123)
Browse files Browse the repository at this point in the history
* Always call `shutdown`

* add test

---------

Co-authored-by: David Mulcahey <[email protected]>
  • Loading branch information
puddly and dmulcahey authored Aug 3, 2024
1 parent 237d1a7 commit 68b43ef
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
18 changes: 18 additions & 0 deletions tests/test_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,24 @@ async def test_device_left(
assert zha_dev_basic.on_network is False


async def test_gateway_startup_failure(
zha_data: ZHAData,
) -> None:
"""Test shutdown called when gateway init fails."""

zha_gateway = await Gateway.async_from_config(zha_data)

with (
patch("zha.application.gateway.Gateway.load_devices", side_effect=Exception),
pytest.raises(Exception),
):
zha_gateway.shutdown = AsyncMock(wraps=zha_gateway.shutdown)
await zha_gateway.async_initialize()
await zha_gateway.async_block_till_done()

assert zha_gateway.shutdown.await_count == 1


async def test_gateway_group_methods(
zha_gateway: Gateway,
device_light_1, # pylint: disable=redefined-outer-name
Expand Down
17 changes: 10 additions & 7 deletions zha/application/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ async def async_from_config(cls, config: ZHAData) -> Self:

return instance

async def async_initialize(self) -> None:
async def _async_initialize(self) -> None:
"""Initialize controller and connect radio."""
discovery.DEVICE_PROBE.initialize(self)
discovery.ENDPOINT_PROBE.initialize(self)
Expand All @@ -239,12 +239,7 @@ async def async_initialize(self) -> None:
start_radio=False,
)

try:
await self.application_controller.startup(auto_form=True)
except Exception:
# Explicitly shut down the controller application on failure
await self.application_controller.shutdown()
raise
await self.application_controller.startup(auto_form=True)

self.coordinator_zha_device = self.get_or_create_device(
self._find_coordinator_device()
Expand All @@ -258,6 +253,14 @@ async def async_initialize(self) -> None:
self.global_updater.start()
self._device_availability_checker.start()

async def async_initialize(self) -> None:
"""Initialize controller and connect radio."""
try:
await self._async_initialize()
except Exception:
await self.shutdown()
raise

def connection_lost(self, exc: Exception) -> None:
"""Handle connection lost event."""
_LOGGER.debug("Connection to the radio was lost: %r", exc)
Expand Down

0 comments on commit 68b43ef

Please sign in to comment.