From e19d68c76c621b98ec853d627befa867fc0ca5d1 Mon Sep 17 00:00:00 2001 From: Eva Lott Date: Mon, 8 Apr 2024 09:13:56 +0100 Subject: [PATCH] added a break condition to `_ctrl_read_forever` --- src/pandablocks/asyncio.py | 8 ++++++++ tests/test_asyncio.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/src/pandablocks/asyncio.py b/src/pandablocks/asyncio.py index b420f5b6..f7f9bf42 100644 --- a/src/pandablocks/asyncio.py +++ b/src/pandablocks/asyncio.py @@ -99,8 +99,16 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() async def _ctrl_read_forever(self, reader: asyncio.StreamReader): + """Continually read data from the stream reader and add to the data queue. + + Args: + reader: The `StreamReader` to read from + """ while True: received = await reader.read(4096) + if received == b"": + raise ConnectionError("Received an empty packet. Closing connection.") + try: to_send = self._ctrl_connection.receive_bytes(received) await self._ctrl_stream.write_and_drain(to_send) diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index dc54e1bb..b63a06ee 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -104,6 +104,37 @@ async def test_asyncio_data_timeout(dummy_server_async, fast_dump): "This goes forever, when it runs out of data we will get our timeout" +async def test_asyncio_empty_frame_error(): + dummy_data = [b"ABC"] * 10 + [b""] + dummy_data_iter = iter(dummy_data) + + async def dummy_read(n): + return dummy_data_iter.__next__() + + reader = asyncio.StreamReader() + reader.read = dummy_read + + written = [] + + class DummyControlStream: + async def write_and_drain(self, data): + written.append(data) + + class DummyControlConnection: + def receive_bytes(self, data): + return data + + client = AsyncioClient("localhost") + client._ctrl_stream = DummyControlStream() + client._ctrl_connection = DummyControlConnection() + with pytest.raises( + ConnectionError, match="Received an empty packet. Closing connection." + ): + await client._ctrl_read_forever(reader) + assert written == dummy_data[:-1] + client.close() + + @pytest.mark.asyncio async def test_asyncio_connects(dummy_server_async: DummyServer): async with AsyncioClient("localhost") as client: