Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] dedicated thread for zigpy-znp - the bellows approach #184

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions zigpy_znp/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
CallbackResponseListener,
)
from zigpy_znp.frames import GeneralFrame
from zigpy_znp.exceptions import CommandNotRecognized, InvalidCommandResponse
from zigpy_znp.exceptions import (
InvalidFrame,
CommandNotRecognized,
InvalidCommandResponse,
)
from zigpy_znp.types.nvids import ExNvIds, OsalNvIds

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -715,7 +719,7 @@ async def connect(self, *, test_port=True) -> None:
self.close()
raise

LOGGER.debug("Connected to %s", self._uart.url)
LOGGER.debug("Connected to %s", self._uart.get_url())

def connection_made(self) -> None:
"""
Expand Down Expand Up @@ -792,7 +796,7 @@ def remove_listener(self, listener: BaseResponseListener) -> None:
counts[OneShotResponseListener],
)

def frame_received(self, frame: GeneralFrame) -> bool | None:
def frame_received(self, frame: GeneralFrame) -> None:
"""
Called when a frame has been received. Returns whether or not the frame was
handled by any listener.
Expand All @@ -802,7 +806,7 @@ def frame_received(self, frame: GeneralFrame) -> bool | None:

if frame.header not in c.COMMANDS_BY_ID:
LOGGER.error("Received an unknown frame: %s", frame)
return False
raise InvalidFrame("Invalid command id")

command_cls = c.COMMANDS_BY_ID[frame.header]

Expand All @@ -813,7 +817,9 @@ def frame_received(self, frame: GeneralFrame) -> bool | None:
# https://github.com/home-assistant/core/issues/50005
if command_cls == c.ZDO.ParentAnnceRsp.Callback:
LOGGER.warning("Failed to parse broken %s as %s", frame, command_cls)
return False
raise InvalidFrame(
"Parsing frame %s ad command %s failed", frame, command_cls
)

raise

Expand Down Expand Up @@ -844,8 +850,6 @@ def frame_received(self, frame: GeneralFrame) -> bool | None:
if not matched:
self._unhandled_command(command)

return matched

def _unhandled_command(self, command: t.CommandBase):
"""
Called when a command that is not handled by any listener is received.
Expand Down
122 changes: 122 additions & 0 deletions zigpy_znp/thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
import functools
import logging
import sys

LOGGER = logging.getLogger(__name__)


class EventLoopThread:
"""Run a parallel event loop in a separate thread."""

def __init__(self):
self.loop = None
self.thread_complete = None

def run_coroutine_threadsafe(self, coroutine):
current_loop = asyncio.get_event_loop()
future = asyncio.run_coroutine_threadsafe(coroutine, self.loop)
return asyncio.wrap_future(future, loop=current_loop)

def _thread_main(self, init_task):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)

try:
self.loop.run_until_complete(init_task)
self.loop.run_forever()
finally:
self.loop.close()
self.loop = None

async def start(self):
current_loop = asyncio.get_event_loop()
if self.loop is not None and not self.loop.is_closed():
return

executor_opts = {"max_workers": 1}
if sys.version_info[:2] >= (3, 6):
executor_opts["thread_name_prefix"] = __name__
executor = ThreadPoolExecutor(**executor_opts)

thread_started_future = current_loop.create_future()

async def init_task():
current_loop.call_soon_threadsafe(thread_started_future.set_result, None)

# Use current loop so current loop has a reference to the long-running thread
# as one of its tasks
thread_complete = current_loop.run_in_executor(
executor, self._thread_main, init_task()
)
self.thread_complete = thread_complete
current_loop.call_soon(executor.shutdown, False)
await thread_started_future
return thread_complete

def force_stop(self):
if self.loop is None:
return

def cancel_tasks_and_stop_loop():
tasks = asyncio.all_tasks(loop=self.loop)

for task in tasks:
self.loop.call_soon_threadsafe(task.cancel)

gather = asyncio.gather(*tasks, return_exceptions=True)
gather.add_done_callback(
lambda _: self.loop.call_soon_threadsafe(self.loop.stop)
)

self.loop.call_soon_threadsafe(cancel_tasks_and_stop_loop)


class ThreadsafeProxy:
"""Proxy class which enforces threadsafe non-blocking calls
This class can be used to wrap an object to ensure any calls
using that object's methods are done on a particular event loop
"""

def __init__(self, obj, obj_loop):
self._obj = obj
self._obj_loop = obj_loop

def __getattr__(self, name):
func = getattr(self._obj, name)
if not callable(func):
raise TypeError(
"Can only use ThreadsafeProxy with callable attributes: {}.{}".format(
self._obj.__class__.__name__, name
)
)

def func_wrapper(*args, **kwargs):
loop = self._obj_loop
curr_loop = asyncio.get_event_loop()
call = functools.partial(func, *args, **kwargs)
if loop == curr_loop:
return call()
if loop.is_closed():
# Disconnected
LOGGER.warning("Attempted to use a closed event loop")
return
if asyncio.iscoroutinefunction(func):
future = asyncio.run_coroutine_threadsafe(call(), loop)
return asyncio.wrap_future(future, loop=curr_loop)
else:

def check_result_wrapper():
result = call()
if result is not None:
raise TypeError(
(
"ThreadsafeProxy can only wrap functions with no return"
"value \nUse an async method to return values: {}.{}"
).format(self._obj.__class__.__name__, name)
)

loop.call_soon_threadsafe(check_result_wrapper)

return func_wrapper
32 changes: 31 additions & 1 deletion zigpy_znp/uart.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import zigpy_znp.frames as frames
import zigpy_znp.logger as log
from zigpy_znp.types import Bytes
from zigpy_znp.thread import EventLoopThread, ThreadsafeProxy
from zigpy_znp.exceptions import InvalidFrame

LOGGER = logging.getLogger(__name__)
Expand All @@ -25,6 +26,7 @@ def __init__(self, api, *, url: str | None = None) -> None:
self._api = api
self._transport = None
self._connected_event = asyncio.Event()
self._connection_done_event = asyncio.Event()

self.url = url

Expand All @@ -46,6 +48,9 @@ def connection_lost(self, exc: Exception | None) -> None:
if exc is not None:
LOGGER.warning("Lost connection", exc_info=exc)

if self._connection_done_event:
self._connection_done_event.set()

if self._api is not None:
self._api.connection_lost(exc)

Expand Down Expand Up @@ -157,8 +162,11 @@ def __repr__(self) -> str:
f">"
)

async def get_url(self):
return self.url


async def connect(config: conf.ConfigType, api) -> ZnpMtProtocol:
async def _connect(config: conf.ConfigType, api) -> ZnpMtProtocol:
loop = asyncio.get_running_loop()

port = config[conf.CONF_DEVICE_PATH]
Expand All @@ -181,3 +189,25 @@ async def connect(config: conf.ConfigType, api) -> ZnpMtProtocol:
LOGGER.debug("Connected to %s at %s baud", port, baudrate)

return protocol


async def connect(
config: conf.ConfigType, api, use_thread=True
) -> ZnpMtProtocol | ThreadsafeProxy:
if use_thread:
application = ThreadsafeProxy(api, asyncio.get_event_loop())
thread = EventLoopThread()
await thread.start()
try:
protocol = await thread.run_coroutine_threadsafe(
_connect(config, application)
)
except Exception:
thread.force_stop()
raise

thread_safe_protocol = ThreadsafeProxy(protocol, thread.loop)
return thread_safe_protocol
else:
protocol = await _connect(config, api)
return protocol