diff --git a/src/controller/python/chip/ChipDeviceCtrl.py b/src/controller/python/chip/ChipDeviceCtrl.py index 369260787d9af8..0f1809c235be56 100644 --- a/src/controller/python/chip/ChipDeviceCtrl.py +++ b/src/controller/python/chip/ChipDeviceCtrl.py @@ -823,6 +823,56 @@ def deviceAvailable(self, device, err): return DeviceProxyWrapper(returnDevice, self._dmLib) + async def GetConnectedDevice(self, nodeid, allowPASE=True, timeoutMs: int = None): + ''' Returns DeviceProxyWrapper upon success.''' + self.CheckIsActive() + + returnDevice = c_void_p(None) + + if allowPASE: + res = self._ChipStack.Call(lambda: self._dmLib.pychip_GetDeviceBeingCommissioned( + self.devCtrl, nodeid, byref(returnDevice)), timeoutMs) + if res.is_success: + logging.info('Using PASE connection') + return DeviceProxyWrapper(returnDevice) + + eventLoop = asyncio.get_running_loop() + future = eventLoop.create_future() + + class DeviceAvailableClosure(): + def __init__(self, future: asyncio.Future): + self.returnDevice = c_void_p(None) + self.returnErr = None + self.future = future + + def _deviceAvailable(self): + if self.returnDevice.value is not None: + self.future.set_result(self.returnDevice) + else: + self.future.set_exception(self.returnErr.to_exception()) + + def deviceAvailable(self, device, err): + self.returnDevice = c_void_p(device) + self.returnErr = err + eventLoop.call_soon_threadsafe(self._deviceAvailable) + ctypes.pythonapi.Py_DecRef(ctypes.py_object(self)) + + closure = DeviceAvailableClosure(future) + ctypes.pythonapi.Py_IncRef(ctypes.py_object(closure)) + self._ChipStack.Call(lambda: self._dmLib.pychip_GetConnectedDeviceByNodeId( + self.devCtrl, nodeid, ctypes.py_object(closure), _DeviceAvailableCallback), + timeoutMs).raise_on_error() + + # The callback might have been received synchronously (during self._ChipStack.Call()). + # In that case the Future has already been set it will return immediately + if (timeoutMs): + timeout = float(timeoutMs) / 1000 + await asyncio.wait_for(future, timeout=timeout) + else: + await future + + return DeviceProxyWrapper(future.result(), self._dmLib) + def ComputeRoundTripTimeout(self, nodeid, upperLayerProcessingTimeoutMs: int = 0): ''' Returns a computed timeout value based on the round-trip time it takes for the peer at the other end of the session to receive a message, process it and send it back. This is computed based on the session type, the type of transport, @@ -887,7 +937,7 @@ async def TestOnlySendBatchCommands(self, nodeid: int, commands: typing.List[Clu eventLoop = asyncio.get_running_loop() future = eventLoop.create_future() - device = self.GetConnectedDeviceSync(nodeid, timeoutMs=interactionTimeoutMs) + device = await self.GetConnectedDevice(nodeid, timeoutMs=interactionTimeoutMs) ClusterCommand.TestOnlySendBatchCommands( future, eventLoop, device.deviceProxy, commands, @@ -908,7 +958,7 @@ async def TestOnlySendCommandTimedRequestFlagWithNoTimedInvoke(self, nodeid: int eventLoop = asyncio.get_running_loop() future = eventLoop.create_future() - device = self.GetConnectedDeviceSync(nodeid, timeoutMs=None) + device = await self.GetConnectedDevice(nodeid, timeoutMs=None) ClusterCommand.TestOnlySendCommandTimedRequestFlagWithNoTimedInvoke( future, eventLoop, responseType, device.deviceProxy, ClusterCommand.CommandPath( EndpointId=endpoint, @@ -940,7 +990,7 @@ async def SendCommand(self, nodeid: int, endpoint: int, payload: ClusterObjects. eventLoop = asyncio.get_running_loop() future = eventLoop.create_future() - device = self.GetConnectedDeviceSync(nodeid, timeoutMs=interactionTimeoutMs) + device = await self.GetConnectedDevice(nodeid, timeoutMs=interactionTimeoutMs) ClusterCommand.SendCommand( future, eventLoop, responseType, device.deviceProxy, ClusterCommand.CommandPath( EndpointId=endpoint, @@ -981,7 +1031,7 @@ async def SendBatchCommands(self, nodeid: int, commands: typing.List[ClusterComm eventLoop = asyncio.get_running_loop() future = eventLoop.create_future() - device = self.GetConnectedDeviceSync(nodeid, timeoutMs=interactionTimeoutMs) + device = await self.GetConnectedDevice(nodeid, timeoutMs=interactionTimeoutMs) ClusterCommand.SendBatchCommands( future, eventLoop, device.deviceProxy, commands, @@ -1031,7 +1081,7 @@ async def WriteAttribute(self, nodeid: int, eventLoop = asyncio.get_running_loop() future = eventLoop.create_future() - device = self.GetConnectedDeviceSync(nodeid, timeoutMs=interactionTimeoutMs) + device = await self.GetConnectedDevice(nodeid, timeoutMs=interactionTimeoutMs) attrs = [] for v in attributes: @@ -1259,7 +1309,7 @@ async def Read(self, nodeid: int, attributes: typing.List[typing.Union[ eventLoop = asyncio.get_running_loop() future = eventLoop.create_future() - device = self.GetConnectedDeviceSync(nodeid) + device = await self.GetConnectedDevice(nodeid) attributePaths = [self._parseAttributePathTuple( v) for v in attributes] if attributes else None clusterDataVersionFilters = [self._parseDataVersionFilterTuple(