Skip to content

Commit

Permalink
[Python] Implement async friendly GetConnectedDevice
Browse files Browse the repository at this point in the history
Currently GetConnectedDeviceSync() is blocking e.g. when a new session
needs to be created. This is not asyncio friendly as it blocks the
whole event loop.

Implement a asyncio friendly variant GetConnectedDevice() which is
a co-routine function which can be awaited.
  • Loading branch information
agners committed Mar 27, 2024
1 parent 60b6beb commit 93ed949
Showing 1 changed file with 56 additions and 6 deletions.
62 changes: 56 additions & 6 deletions src/controller/python/chip/ChipDeviceCtrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,56 @@ def deviceAvailable(self, device, err):

return DeviceProxyWrapper(returnDevice, self._dmLib)

async def GetConnectedDevice(self, nodeid, allowPASE=True, timeoutMs: int = None):
''' Returns DeviceProxyWrapper upon success.'''
self.CheckIsActive()

if allowPASE:
returnDevice = c_void_p(None)
res = self._ChipStack.Call(lambda: self._dmLib.pychip_GetDeviceBeingCommissioned(
self.devCtrl, nodeid, byref(returnDevice)), timeoutMs)
if res.is_success:
logging.info('Using PASE connection')
return DeviceProxyWrapper(returnDevice)

eventLoop = asyncio.get_running_loop()
future = eventLoop.create_future()

class DeviceAvailableClosure():
def __init__(self, loop, future: asyncio.Future):
self._returnDevice = c_void_p(None)
self._returnErr = None
self._event_loop = loop
self._future = future

def _deviceAvailable(self):
if self._returnDevice.value is not None:
self._future.set_result(self._returnDevice)
else:
self._future.set_exception(self._returnErr.to_exception())

def deviceAvailable(self, device, err):
self._returnDevice = c_void_p(device)
self._returnErr = err
self._event_loop.call_soon_threadsafe(self._deviceAvailable)
ctypes.pythonapi.Py_DecRef(ctypes.py_object(self))

closure = DeviceAvailableClosure(eventLoop, future)
ctypes.pythonapi.Py_IncRef(ctypes.py_object(closure))
self._ChipStack.Call(lambda: self._dmLib.pychip_GetConnectedDeviceByNodeId(
self.devCtrl, nodeid, ctypes.py_object(closure), _DeviceAvailableCallback),
timeoutMs).raise_on_error()

# The callback might have been received synchronously (during self._ChipStack.Call()).
# In that case the Future has already been set it will return immediately
if (timeoutMs):
timeout = float(timeoutMs) / 1000
await asyncio.wait_for(future, timeout=timeout)
else:
await future

return DeviceProxyWrapper(future.result(), self._dmLib)

def ComputeRoundTripTimeout(self, nodeid, upperLayerProcessingTimeoutMs: int = 0):
''' Returns a computed timeout value based on the round-trip time it takes for the peer at the other end of the session to
receive a message, process it and send it back. This is computed based on the session type, the type of transport,
Expand Down Expand Up @@ -887,7 +937,7 @@ async def TestOnlySendBatchCommands(self, nodeid: int, commands: typing.List[Clu
eventLoop = asyncio.get_running_loop()
future = eventLoop.create_future()

device = self.GetConnectedDeviceSync(nodeid, timeoutMs=interactionTimeoutMs)
device = await self.GetConnectedDevice(nodeid, timeoutMs=interactionTimeoutMs)

ClusterCommand.TestOnlySendBatchCommands(
future, eventLoop, device.deviceProxy, commands,
Expand All @@ -908,7 +958,7 @@ async def TestOnlySendCommandTimedRequestFlagWithNoTimedInvoke(self, nodeid: int
eventLoop = asyncio.get_running_loop()
future = eventLoop.create_future()

device = self.GetConnectedDeviceSync(nodeid, timeoutMs=None)
device = await self.GetConnectedDevice(nodeid, timeoutMs=None)
ClusterCommand.TestOnlySendCommandTimedRequestFlagWithNoTimedInvoke(
future, eventLoop, responseType, device.deviceProxy, ClusterCommand.CommandPath(
EndpointId=endpoint,
Expand Down Expand Up @@ -940,7 +990,7 @@ async def SendCommand(self, nodeid: int, endpoint: int, payload: ClusterObjects.
eventLoop = asyncio.get_running_loop()
future = eventLoop.create_future()

device = self.GetConnectedDeviceSync(nodeid, timeoutMs=interactionTimeoutMs)
device = await self.GetConnectedDevice(nodeid, timeoutMs=interactionTimeoutMs)
ClusterCommand.SendCommand(
future, eventLoop, responseType, device.deviceProxy, ClusterCommand.CommandPath(
EndpointId=endpoint,
Expand Down Expand Up @@ -981,7 +1031,7 @@ async def SendBatchCommands(self, nodeid: int, commands: typing.List[ClusterComm
eventLoop = asyncio.get_running_loop()
future = eventLoop.create_future()

device = self.GetConnectedDeviceSync(nodeid, timeoutMs=interactionTimeoutMs)
device = await self.GetConnectedDevice(nodeid, timeoutMs=interactionTimeoutMs)

ClusterCommand.SendBatchCommands(
future, eventLoop, device.deviceProxy, commands,
Expand Down Expand Up @@ -1031,7 +1081,7 @@ async def WriteAttribute(self, nodeid: int,
eventLoop = asyncio.get_running_loop()
future = eventLoop.create_future()

device = self.GetConnectedDeviceSync(nodeid, timeoutMs=interactionTimeoutMs)
device = await self.GetConnectedDevice(nodeid, timeoutMs=interactionTimeoutMs)

attrs = []
for v in attributes:
Expand Down Expand Up @@ -1259,7 +1309,7 @@ async def Read(self, nodeid: int, attributes: typing.List[typing.Union[
eventLoop = asyncio.get_running_loop()
future = eventLoop.create_future()

device = self.GetConnectedDeviceSync(nodeid)
device = await self.GetConnectedDevice(nodeid)
attributePaths = [self._parseAttributePathTuple(
v) for v in attributes] if attributes else None
clusterDataVersionFilters = [self._parseDataVersionFilterTuple(
Expand Down

0 comments on commit 93ed949

Please sign in to comment.