Skip to content

Commit

Permalink
Support coroutines for ErrorCallback
Browse files Browse the repository at this point in the history
Allow to pass coroutines to the ErrorCallback. Instead of using a
regular Python function make use of asyncio by default as well.
  • Loading branch information
agners committed Mar 8, 2024
1 parent c34701a commit 7a45132
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions src/controller/python/chip/clusters/Attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,9 +557,9 @@ def SetEventUpdateCallback(self, callback: Callable[[EventReadResult, Subscripti
if callback is not None:
self._onEventChangeCb = callback

def SetErrorCallback(self, callback: Callable[[int, SubscriptionTransaction], None]):
def SetErrorCallback(self, callback: Callable[[int, SubscriptionTransaction], Union[None, Awaitable[None]]]):
'''
Sets the callback function in case a subscription error occured,
Sets the callback function in case a subscription error occurred,
accepts a Callable accepts an error code and the cached data.
'''
if callback is not None:
Expand All @@ -574,7 +574,7 @@ def OnEventChangeCb(self) -> Callable[[EventReadResult, SubscriptionTransaction]
return self._onEventChangeCb

@property
def OnErrorCb(self) -> Callable[[int, SubscriptionTransaction], None]:
def OnErrorCb(self) -> Callable[[int, SubscriptionTransaction], Union[None, Awaitable[None]]]:
return self._onErrorCb

@property
Expand Down Expand Up @@ -788,7 +788,7 @@ async def _handleReportEnd(self):
# Clear it out once we've notified of all changes in this transaction.
self._changedPathSet = set()

def _handleDone(self):
async def _handleDone(self):
#
# We only set the exception/result on the future in this _handleDone call (if it hasn't
# already been set yet, which can be in the case of subscriptions) since doing so earlier
Expand All @@ -798,7 +798,10 @@ def _handleDone(self):
if not self._future.done():
if self._resultError:
if self._subscription_handler:
self._subscription_handler.OnErrorCb(self._resultError, self._subscription_handler)
if inspect.iscoroutinefunction(self._subscription_handler.OnErrorCb):
await self._subscription_handler.OnErrorCb(self._resultError, self._subscription_handler)
else:
self._subscription_handler.OnErrorCb(self._resultError, self._subscription_handler)
else:
self._future.set_exception(chip.exceptions.ChipStackError(self._resultError))
else:
Expand All @@ -813,7 +816,7 @@ def _handleDone(self):
ctypes.pythonapi.Py_DecRef(ctypes.py_object(self))

def handleDone(self):
self._event_loop.call_soon_threadsafe(self._handleDone)
asyncio.run_coroutine_threadsafe(self._handleDone(), self._event_loop)

def handleReportBegin(self):
pass
Expand Down

0 comments on commit 7a45132

Please sign in to comment.