Skip to content

Commit

Permalink
coverage and adjust shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
dmulcahey committed Apr 3, 2024
1 parent d16cfea commit f55df34
Show file tree
Hide file tree
Showing 5 changed files with 420 additions and 20 deletions.
151 changes: 150 additions & 1 deletion tests/test_async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import asyncio
import functools
import time
from unittest.mock import MagicMock, patch
from typing import Any
from unittest.mock import MagicMock, Mock, patch

import pytest

from zha import async_ as zha_async
from zha.application.gateway import Gateway
from zha.async_ import AsyncUtilMixin, ZHAJob, ZHAJobType, create_eager_task
from zha.decorators import callback
Expand Down Expand Up @@ -493,6 +495,9 @@ async def test_add_job_with_none(zha_gateway: Gateway) -> None:
with pytest.raises(ValueError):
zha_gateway.async_add_job(None, "test_arg")

with pytest.raises(ValueError):
zha_gateway.add_job(None, "test_arg")


async def test_async_functions_with_callback(zha_gateway: Gateway) -> None:
"""Test we deal with async functions accidentally marked as callback."""
Expand Down Expand Up @@ -693,3 +698,147 @@ async def _async_add_executor_job():
await zha_gateway.async_block_till_done()
assert len(calls) == 1
await task


@patch("concurrent.futures.Future")
@patch("threading.get_ident")
def test_run_callback_threadsafe_from_inside_event_loop(mock_ident, _) -> None:
"""Testing calling run_callback_threadsafe from inside an event loop."""
callback_fn = MagicMock()

loop = Mock(spec=["call_soon_threadsafe"])

loop._thread_ident = None
mock_ident.return_value = 5
zha_async.run_callback_threadsafe(loop, callback_fn)
assert len(loop.call_soon_threadsafe.mock_calls) == 1

loop._thread_ident = 5
mock_ident.return_value = 5
with pytest.raises(RuntimeError):
zha_async.run_callback_threadsafe(loop, callback_fn)
assert len(loop.call_soon_threadsafe.mock_calls) == 1

loop._thread_ident = 1
mock_ident.return_value = 5
zha_async.run_callback_threadsafe(loop, callback_fn)
assert len(loop.call_soon_threadsafe.mock_calls) == 2


async def test_gather_with_limited_concurrency() -> None:
"""Test gather_with_limited_concurrency limits the number of running tasks."""

runs = 0
now_time = time.time()

async def _increment_runs_if_in_time():
if time.time() - now_time > 0.1:
return -1

nonlocal runs
runs += 1
await asyncio.sleep(0.1)
return runs

results = await zha_async.gather_with_limited_concurrency(
2, *(_increment_runs_if_in_time() for i in range(4))
)

assert results == [2, 2, -1, -1]


async def test_shutdown_run_callback_threadsafe(zha_gateway: Gateway) -> None:
"""Test we can shutdown run_callback_threadsafe."""
zha_async.shutdown_run_callback_threadsafe(zha_gateway.loop)
callback_fn = MagicMock()

with pytest.raises(RuntimeError):
zha_async.run_callback_threadsafe(zha_gateway.loop, callback_fn)


async def test_run_callback_threadsafe(zha_gateway: Gateway) -> None:
"""Test run_callback_threadsafe runs code in the event loop."""
it_ran = False

def callback_fn():
nonlocal it_ran
it_ran = True

assert zha_async.run_callback_threadsafe(zha_gateway.loop, callback_fn)
assert it_ran is False

# Verify that async_block_till_done will flush
# out the callback
await zha_gateway.async_block_till_done()
assert it_ran is True


async def test_callback_is_always_scheduled(zha_gateway: Gateway) -> None:
"""Test run_callback_threadsafe always calls call_soon_threadsafe before checking for shutdown."""
# We have to check the shutdown state AFTER the callback is scheduled otherwise
# the function could continue on and the caller call `future.result()` after
# the point in the main thread where callbacks are no longer run.

callback_fn = MagicMock()
zha_async.shutdown_run_callback_threadsafe(zha_gateway.loop)

with (
patch.object(
zha_gateway.loop, "call_soon_threadsafe"
) as mock_call_soon_threadsafe,
pytest.raises(RuntimeError),
):
zha_async.run_callback_threadsafe(zha_gateway.loop, callback_fn)

mock_call_soon_threadsafe.assert_called_once()


async def test_create_eager_task_312(zha_gateway: Gateway) -> None: # pylint: disable=unused-argument
"""Test create_eager_task schedules a task eagerly in the event loop.
For Python 3.12+, the task is scheduled eagerly in the event loop.
"""
events = []

async def _normal_task():
events.append("normal")

async def _eager_task():
events.append("eager")

task1 = zha_async.create_eager_task(_eager_task())
task2 = asyncio.create_task(_normal_task())

assert events == ["eager"]

await asyncio.sleep(0)
assert events == ["eager", "normal"]
await task1
await task2


async def test_shutdown_calls_block_till_done_after_shutdown_run_callback_threadsafe(
zha_gateway: Gateway,
) -> None:
"""Ensure shutdown_run_callback_threadsafe is called before the final async_block_till_done."""
stop_calls: list[Any] = []

async def _record_block_till_done(wait_background_tasks: bool = False): # pylint: disable=unused-argument
nonlocal stop_calls
stop_calls.append("async_block_till_done")

def _record_shutdown_run_callback_threadsafe(loop):
nonlocal stop_calls
stop_calls.append(("shutdown_run_callback_threadsafe", loop))

with (
patch.object(zha_gateway, "async_block_till_done", _record_block_till_done),
patch(
"zha.async_.shutdown_run_callback_threadsafe",
_record_shutdown_run_callback_threadsafe,
),
):
await zha_gateway.shutdown()

assert stop_calls[-2] == ("shutdown_run_callback_threadsafe", zha_gateway.loop)
assert stop_calls[-1] == "async_block_till_done"
2 changes: 2 additions & 0 deletions tests/test_cluster_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,9 @@ async def test_poll_control_ikea(poll_control_device: Device) -> None:
poll_control_ch = poll_control_device._endpoints[1].all_cluster_handlers["1:0x0020"]
cluster = poll_control_ch.cluster

delattr(poll_control_device, "manufacturer_code")
poll_control_device.device.node_desc.manufacturer_code = 4476

with mock.patch.object(cluster, "set_long_poll_interval", set_long_poll_mock):
await poll_control_ch.check_in_response(33)

Expand Down
111 changes: 111 additions & 0 deletions tests/test_thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Test ZHA thread utils."""

import asyncio
from unittest.mock import Mock, patch

import pytest

from zha import async_
from zha.application.gateway import Gateway
from zha.async_ import ThreadWithException, run_callback_threadsafe


async def test_thread_with_exception_invalid(zha_gateway: Gateway) -> None:
"""Test throwing an invalid thread exception."""

finish_event = asyncio.Event()

def _do_nothing(*_):
run_callback_threadsafe(zha_gateway.loop, finish_event.set)

test_thread = ThreadWithException(target=_do_nothing)
test_thread.start()
await asyncio.wait_for(finish_event.wait(), timeout=0.1)

with pytest.raises(TypeError):
test_thread.raise_exc(_EmptyClass())
test_thread.join()


async def test_thread_not_started(zha_gateway: Gateway) -> None:
"""Test throwing when the thread is not started."""

test_thread = ThreadWithException(target=lambda *_: None)

with pytest.raises(AssertionError):
test_thread.raise_exc(TimeoutError)


async def test_thread_fails_raise(zha_gateway: Gateway) -> None:
"""Test throwing after already ended."""

finish_event = asyncio.Event()

def _do_nothing(*_):
run_callback_threadsafe(zha_gateway.loop, finish_event.set)

test_thread = ThreadWithException(target=_do_nothing)
test_thread.start()
await asyncio.wait_for(finish_event.wait(), timeout=0.1)
test_thread.join()

with pytest.raises(SystemError):
test_thread.raise_exc(ValueError)


class _EmptyClass:
"""An empty class."""


async def test_deadlock_safe_shutdown_no_threads() -> None:
"""Test we can shutdown without deadlock without any threads to join."""

dead_thread_mock = Mock(
join=Mock(), daemon=False, is_alive=Mock(return_value=False)
)
daemon_thread_mock = Mock(
join=Mock(), daemon=True, is_alive=Mock(return_value=True)
)
mock_threads = [
dead_thread_mock,
daemon_thread_mock,
]

with patch("threading.enumerate", return_value=mock_threads):
async_.deadlock_safe_shutdown()

assert not dead_thread_mock.join.called
assert not daemon_thread_mock.join.called


async def test_deadlock_safe_shutdown() -> None:
"""Test we can shutdown without deadlock."""

normal_thread_mock = Mock(
join=Mock(), daemon=False, is_alive=Mock(return_value=True)
)
dead_thread_mock = Mock(
join=Mock(), daemon=False, is_alive=Mock(return_value=False)
)
daemon_thread_mock = Mock(
join=Mock(), daemon=True, is_alive=Mock(return_value=True)
)
exception_thread_mock = Mock(
join=Mock(side_effect=Exception), daemon=False, is_alive=Mock(return_value=True)
)
mock_threads = [
normal_thread_mock,
dead_thread_mock,
daemon_thread_mock,
exception_thread_mock,
]

with patch("threading.enumerate", return_value=mock_threads):
async_.deadlock_safe_shutdown()

expected_timeout = async_.THREADING_SHUTDOWN_TIMEOUT / 2

assert normal_thread_mock.join.call_args[0] == (expected_timeout,)
assert not dead_thread_mock.join.called
assert not daemon_thread_mock.join.called
assert exception_thread_mock.join.call_args[0] == (expected_timeout,)
21 changes: 5 additions & 16 deletions zha/application/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import asyncio
from collections.abc import Iterable
from contextlib import suppress
from dataclasses import dataclass
from datetime import timedelta
Expand Down Expand Up @@ -688,31 +687,21 @@ async def shutdown(self) -> None:
_LOGGER.debug("Ignoring duplicate shutdown event")
return

self.shutting_down = True

self.global_updater.stop()
self._device_availability_checker.stop()

async def _cancel_tasks(tasks_to_cancel: Iterable) -> None:
tasks = [t for t in tasks_to_cancel if not (t.done() or t.cancelled())]
for task in tasks:
_LOGGER.debug("Cancelling task: %s", task)
task.cancel()
with suppress(asyncio.CancelledError):
await asyncio.gather(*tasks, return_exceptions=True)

await _cancel_tasks(self._background_tasks)
await _cancel_tasks(self._tracked_completable_tasks)
await _cancel_tasks(self._device_init_tasks.values())
super().async_shutdown()

for device in self._devices.values():
await device.on_remove()

_LOGGER.debug("Shutting down ZHA ControllerApplication")
self.shutting_down = True

await self.application_controller.shutdown()
self.application_controller = None
await asyncio.sleep(0.1) # give bellows thread callback a chance to run

await super().shutdown()

self._devices.clear()
self._groups.clear()

Expand Down
Loading

0 comments on commit f55df34

Please sign in to comment.