From 8729a25c195f1592a54e9aff6802c46388d154aa Mon Sep 17 00:00:00 2001 From: AlexWells Date: Wed, 5 Apr 2023 16:01:38 +0100 Subject: [PATCH] Raise exceptions when server disconnects Now if you try and send data, either the write or the read can time out if you provide an optional timeout value --- pandablocks/asyncio.py | 24 +++++++++++++++++++----- tests/test_asyncio.py | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 5 deletions(-) diff --git a/pandablocks/asyncio.py b/pandablocks/asyncio.py index f64dea818..0f93126e9 100644 --- a/pandablocks/asyncio.py +++ b/pandablocks/asyncio.py @@ -26,10 +26,17 @@ def writer(self) -> StreamWriter: assert self._writer, "connect() not called yet" return self._writer - async def write_and_drain(self, data: bytes): + async def write_and_drain(self, data: bytes, timeout: Optional[float] = None): writer = self.writer writer.write(data) - await writer.drain() + + # Cannot simply await the drain, as if the remote end has disconnected + # then the drain will never complete as the OS cannot clear its send buffer. + _, pending = await asyncio.wait([writer.drain()], timeout=timeout) + if len(pending): + for task in pending: + task.cancel() + raise asyncio.TimeoutError("Timeout writing data") async def connect(self, host: str, port: int): self._reader, self._writer = await asyncio.open_connection(host, port) @@ -70,6 +77,13 @@ async def connect(self): self._ctrl_read_forever(self._ctrl_stream.reader) ) + def is_connected(self): + """True if there is a currently active connection. + NOTE: This does not indicate if the remote end is still connected.""" + if self._ctrl_task and not self._ctrl_task.done(): + return True + return False + async def close(self): """Close the control connection, and wait for completion""" assert self._ctrl_task, "connect() not called yet" @@ -95,7 +109,7 @@ async def _ctrl_read_forever(self, reader: asyncio.StreamReader): except Exception: logging.exception(f"Error handling '{received.decode()}'") - async def send(self, command: Command[T]) -> T: + async def send(self, command: Command[T], timeout: Optional[float] = None) -> T: """Send a command to control port of the PandA, returning its response. Args: @@ -105,8 +119,8 @@ async def send(self, command: Command[T]) -> T: # Need to use the id as non-frozen dataclasses don't hash self._ctrl_queues[id(command)] = queue to_send = self._ctrl_connection.send(command) - await self._ctrl_stream.write_and_drain(to_send) - response = await queue.get() + await self._ctrl_stream.write_and_drain(to_send, timeout) + response = await asyncio.wait_for(queue.get(), timeout) if isinstance(response, Exception): raise response else: diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 20043f789..ff52f006c 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -5,6 +5,11 @@ from pandablocks.asyncio import AsyncioClient from pandablocks.commands import CommandException, Get, Put +from .conftest import DummyServer + +# Timeout in seconds +TIMEOUT = 3 + @pytest.mark.asyncio async def test_asyncio_get(dummy_server_async): @@ -35,3 +40,39 @@ async def test_asyncio_data(dummy_server_async, fast_dump, fast_dump_expected): if len(events) == 9: break assert fast_dump_expected == events + + +@pytest.mark.asyncio +async def test_asyncio_connects(dummy_server_async: DummyServer): + async with AsyncioClient("localhost") as client: + assert client.is_connected() + + assert not client.is_connected() + + +@pytest.mark.asyncio +async def test_asyncio_client_uncontactable(): + """Test that a client raises an exception when the remote end is not + contactable""" + client = AsyncioClient("localhost") + with pytest.raises(OSError): + await client.connect() + + +@pytest.mark.asyncio +async def test_asyncio_client_fails_when_cannot_drain(dummy_server_async: DummyServer): + """Test that we don't hang indefinitely when failing to drain data from the OS + send buffer""" + + # Note this value is probably OS-dependant. I found it experimentally. + large_data = b"ABC" * 100000000 + + client = AsyncioClient("localhost") + await client.connect() + await dummy_server_async.close() + with pytest.raises(asyncio.TimeoutError): + await client._ctrl_stream.write_and_drain(large_data, timeout=TIMEOUT) + + # Can't use client.close() as it gets endlessly stuck. Do the important part. + assert client._ctrl_task + client._ctrl_task.cancel()