diff --git a/tests/test_gateway.py b/tests/test_gateway.py index 1b32608f..fc2d8da7 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -166,6 +166,24 @@ async def test_device_left( assert zha_dev_basic.on_network is False +async def test_gateway_startup_failure( + zha_data: ZHAData, +) -> None: + """Test shutdown called when gateway init fails.""" + + zha_gateway = await Gateway.async_from_config(zha_data) + + with ( + patch("zha.application.gateway.Gateway.load_devices", side_effect=Exception), + pytest.raises(Exception), + ): + zha_gateway.shutdown = AsyncMock(wraps=zha_gateway.shutdown) + await zha_gateway.async_initialize() + await zha_gateway.async_block_till_done() + + assert zha_gateway.shutdown.await_count == 1 + + async def test_gateway_group_methods( zha_gateway: Gateway, device_light_1, # pylint: disable=redefined-outer-name diff --git a/zha/application/gateway.py b/zha/application/gateway.py index 17059658..5ba7ff3e 100644 --- a/zha/application/gateway.py +++ b/zha/application/gateway.py @@ -224,7 +224,7 @@ async def async_from_config(cls, config: ZHAData) -> Self: return instance - async def async_initialize(self) -> None: + async def _async_initialize(self) -> None: """Initialize controller and connect radio.""" discovery.DEVICE_PROBE.initialize(self) discovery.ENDPOINT_PROBE.initialize(self) @@ -239,12 +239,7 @@ async def async_initialize(self) -> None: start_radio=False, ) - try: - await self.application_controller.startup(auto_form=True) - except Exception: - # Explicitly shut down the controller application on failure - await self.application_controller.shutdown() - raise + await self.application_controller.startup(auto_form=True) self.coordinator_zha_device = self.get_or_create_device( self._find_coordinator_device() @@ -258,6 +253,14 @@ async def async_initialize(self) -> None: self.global_updater.start() self._device_availability_checker.start() + async def async_initialize(self) -> None: + """Initialize controller and connect radio.""" + try: + await self._async_initialize() + except Exception: + await self.shutdown() + raise + def connection_lost(self, exc: Exception) -> None: """Handle connection lost event.""" _LOGGER.debug("Connection to the radio was lost: %r", exc)