Skip to content

Commit

Permalink
Raise exceptions when server disconnects
Browse files Browse the repository at this point in the history
Now if you try and send data, either the write or the read can time out
if you provide an optional timeout value
  • Loading branch information
AlexanderWells-diamond committed Apr 5, 2023
1 parent 5dc0999 commit 8729a25
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 5 deletions.
24 changes: 19 additions & 5 deletions pandablocks/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,17 @@ def writer(self) -> StreamWriter:
assert self._writer, "connect() not called yet"
return self._writer

async def write_and_drain(self, data: bytes):
async def write_and_drain(self, data: bytes, timeout: Optional[float] = None):
writer = self.writer
writer.write(data)
await writer.drain()

# Cannot simply await the drain, as if the remote end has disconnected
# then the drain will never complete as the OS cannot clear its send buffer.
_, pending = await asyncio.wait([writer.drain()], timeout=timeout)
if len(pending):
for task in pending:
task.cancel()
raise asyncio.TimeoutError("Timeout writing data")

async def connect(self, host: str, port: int):
self._reader, self._writer = await asyncio.open_connection(host, port)
Expand Down Expand Up @@ -70,6 +77,13 @@ async def connect(self):
self._ctrl_read_forever(self._ctrl_stream.reader)
)

def is_connected(self):
"""True if there is a currently active connection.
NOTE: This does not indicate if the remote end is still connected."""
if self._ctrl_task and not self._ctrl_task.done():
return True
return False

async def close(self):
"""Close the control connection, and wait for completion"""
assert self._ctrl_task, "connect() not called yet"
Expand All @@ -95,7 +109,7 @@ async def _ctrl_read_forever(self, reader: asyncio.StreamReader):
except Exception:
logging.exception(f"Error handling '{received.decode()}'")

async def send(self, command: Command[T]) -> T:
async def send(self, command: Command[T], timeout: Optional[float] = None) -> T:
"""Send a command to control port of the PandA, returning its response.
Args:
Expand All @@ -105,8 +119,8 @@ async def send(self, command: Command[T]) -> T:
# Need to use the id as non-frozen dataclasses don't hash
self._ctrl_queues[id(command)] = queue
to_send = self._ctrl_connection.send(command)
await self._ctrl_stream.write_and_drain(to_send)
response = await queue.get()
await self._ctrl_stream.write_and_drain(to_send, timeout)
response = await asyncio.wait_for(queue.get(), timeout)
if isinstance(response, Exception):
raise response
else:
Expand Down
41 changes: 41 additions & 0 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
from pandablocks.asyncio import AsyncioClient
from pandablocks.commands import CommandException, Get, Put

from .conftest import DummyServer

# Timeout in seconds
TIMEOUT = 3


@pytest.mark.asyncio
async def test_asyncio_get(dummy_server_async):
Expand Down Expand Up @@ -35,3 +40,39 @@ async def test_asyncio_data(dummy_server_async, fast_dump, fast_dump_expected):
if len(events) == 9:
break
assert fast_dump_expected == events


@pytest.mark.asyncio
async def test_asyncio_connects(dummy_server_async: DummyServer):
async with AsyncioClient("localhost") as client:
assert client.is_connected()

assert not client.is_connected()


@pytest.mark.asyncio
async def test_asyncio_client_uncontactable():
"""Test that a client raises an exception when the remote end is not
contactable"""
client = AsyncioClient("localhost")
with pytest.raises(OSError):
await client.connect()


@pytest.mark.asyncio
async def test_asyncio_client_fails_when_cannot_drain(dummy_server_async: DummyServer):
"""Test that we don't hang indefinitely when failing to drain data from the OS
send buffer"""

# Note this value is probably OS-dependant. I found it experimentally.
large_data = b"ABC" * 100000000

client = AsyncioClient("localhost")
await client.connect()
await dummy_server_async.close()
with pytest.raises(asyncio.TimeoutError):
await client._ctrl_stream.write_and_drain(large_data, timeout=TIMEOUT)

# Can't use client.close() as it gets endlessly stuck. Do the important part.
assert client._ctrl_task
client._ctrl_task.cancel()

0 comments on commit 8729a25

Please sign in to comment.