From 7a45132751bff98174688d88036ccb6d6c698c7a Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Fri, 8 Mar 2024 00:46:26 +0100 Subject: [PATCH] Support coroutines for ErrorCallback Allow to pass coroutines to the ErrorCallback. Instead of using a regular Python function make use of asyncio by default as well. --- src/controller/python/chip/clusters/Attribute.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/controller/python/chip/clusters/Attribute.py b/src/controller/python/chip/clusters/Attribute.py index 15738e27a00890..fe62f3f11d2f81 100644 --- a/src/controller/python/chip/clusters/Attribute.py +++ b/src/controller/python/chip/clusters/Attribute.py @@ -557,9 +557,9 @@ def SetEventUpdateCallback(self, callback: Callable[[EventReadResult, Subscripti if callback is not None: self._onEventChangeCb = callback - def SetErrorCallback(self, callback: Callable[[int, SubscriptionTransaction], None]): + def SetErrorCallback(self, callback: Callable[[int, SubscriptionTransaction], Union[None, Awaitable[None]]]): ''' - Sets the callback function in case a subscription error occured, + Sets the callback function in case a subscription error occurred, accepts a Callable accepts an error code and the cached data. ''' if callback is not None: @@ -574,7 +574,7 @@ def OnEventChangeCb(self) -> Callable[[EventReadResult, SubscriptionTransaction] return self._onEventChangeCb @property - def OnErrorCb(self) -> Callable[[int, SubscriptionTransaction], None]: + def OnErrorCb(self) -> Callable[[int, SubscriptionTransaction], Union[None, Awaitable[None]]]: return self._onErrorCb @property @@ -788,7 +788,7 @@ async def _handleReportEnd(self): # Clear it out once we've notified of all changes in this transaction. self._changedPathSet = set() - def _handleDone(self): + async def _handleDone(self): # # We only set the exception/result on the future in this _handleDone call (if it hasn't # already been set yet, which can be in the case of subscriptions) since doing so earlier @@ -798,7 +798,10 @@ def _handleDone(self): if not self._future.done(): if self._resultError: if self._subscription_handler: - self._subscription_handler.OnErrorCb(self._resultError, self._subscription_handler) + if inspect.iscoroutinefunction(self._subscription_handler.OnErrorCb): + await self._subscription_handler.OnErrorCb(self._resultError, self._subscription_handler) + else: + self._subscription_handler.OnErrorCb(self._resultError, self._subscription_handler) else: self._future.set_exception(chip.exceptions.ChipStackError(self._resultError)) else: @@ -813,7 +816,7 @@ def _handleDone(self): ctypes.pythonapi.Py_DecRef(ctypes.py_object(self)) def handleDone(self): - self._event_loop.call_soon_threadsafe(self._handleDone) + asyncio.run_coroutine_threadsafe(self._handleDone(), self._event_loop) def handleReportBegin(self): pass